diff --git a/.venv/lib/python3.11/site-packages/google/cloud/__pycache__/extended_operations_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/cloud/__pycache__/extended_operations_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f44f794223631256a093be85abf7b62a0cbea5ac Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/cloud/__pycache__/extended_operations_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/cloud/extended_operations.proto b/.venv/lib/python3.11/site-packages/google/cloud/extended_operations.proto new file mode 100644 index 0000000000000000000000000000000000000000..2f86c3745d64eeae4cb5355d5140246ee84fd30a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/cloud/extended_operations.proto @@ -0,0 +1,150 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains custom annotations that are used by GAPIC generators to +// handle Long Running Operation methods (LRO) that are NOT compliant with +// https://google.aip.dev/151. These annotations are public for technical +// reasons only. Please DO NOT USE them in your protos. +syntax = "proto3"; + +package google.cloud; + +import "google/protobuf/descriptor.proto"; + +option go_package = "google.golang.org/genproto/googleapis/cloud/extendedops;extendedops"; +option java_multiple_files = true; +option java_outer_classname = "ExtendedOperationsProto"; +option java_package = "com.google.cloud"; +option objc_class_prefix = "GAPI"; + +// FieldOptions to match corresponding fields in the initial request, +// polling request and operation response messages. +// +// Example: +// +// In an API-specific operation message: +// +// message MyOperation { +// string http_error_message = 1 [(operation_field) = ERROR_MESSAGE]; +// int32 http_error_status_code = 2 [(operation_field) = ERROR_CODE]; +// string id = 3 [(operation_field) = NAME]; +// Status status = 4 [(operation_field) = STATUS]; +// } +// +// In a polling request message (the one which is used to poll for an LRO +// status): +// +// message MyPollingRequest { +// string operation = 1 [(operation_response_field) = "id"]; +// string project = 2; +// string region = 3; +// } +// +// In an initial request message (the one which starts an LRO): +// +// message MyInitialRequest { +// string my_project = 2 [(operation_request_field) = "project"]; +// string my_region = 3 [(operation_request_field) = "region"]; +// } +// +extend google.protobuf.FieldOptions { + // A field annotation that maps fields in an API-specific Operation object to + // their standard counterparts in google.longrunning.Operation. See + // OperationResponseMapping enum definition. + OperationResponseMapping operation_field = 1149; + + // A field annotation that maps fields in the initial request message + // (the one which started the LRO) to their counterparts in the polling + // request message. For non-standard LRO, the polling response may be missing + // some of the information needed to make a subsequent polling request. The + // missing information (for example, project or region ID) is contained in the + // fields of the initial request message that this annotation must be applied + // to. The string value of the annotation corresponds to the name of the + // counterpart field in the polling request message that the annotated field's + // value will be copied to. + string operation_request_field = 1150; + + // A field annotation that maps fields in the polling request message to their + // counterparts in the initial and/or polling response message. The initial + // and the polling methods return an API-specific Operation object. Some of + // the fields from that response object must be reused in the subsequent + // request (like operation name/ID) to fully identify the polled operation. + // This annotation must be applied to the fields in the polling request + // message, the string value of the annotation must correspond to the name of + // the counterpart field in the Operation response object whose value will be + // copied to the annotated field. + string operation_response_field = 1151; +} + +// MethodOptions to identify the actual service and method used for operation +// status polling. +// +// Example: +// +// In a method, which starts an LRO: +// +// service MyService { +// rpc Foo(MyInitialRequest) returns (MyOperation) { +// option (operation_service) = "MyPollingService"; +// } +// } +// +// In a polling method: +// +// service MyPollingService { +// rpc Get(MyPollingRequest) returns (MyOperation) { +// option (operation_polling_method) = true; +// } +// } +extend google.protobuf.MethodOptions { + // A method annotation that maps an LRO method (the one which starts an LRO) + // to the service, which will be used to poll for the operation status. The + // annotation must be applied to the method which starts an LRO, the string + // value of the annotation must correspond to the name of the service used to + // poll for the operation status. + string operation_service = 1249; + + // A method annotation that marks methods that can be used for polling + // operation status (e.g. the MyPollingService.Get(MyPollingRequest) method). + bool operation_polling_method = 1250; +} + +// An enum to be used to mark the essential (for polling) fields in an +// API-specific Operation object. A custom Operation object may contain many +// different fields, but only few of them are essential to conduct a successful +// polling process. +enum OperationResponseMapping { + // Do not use. + UNDEFINED = 0; + + // A field in an API-specific (custom) Operation object which carries the same + // meaning as google.longrunning.Operation.name. + NAME = 1; + + // A field in an API-specific (custom) Operation object which carries the same + // meaning as google.longrunning.Operation.done. If the annotated field is of + // an enum type, `annotated_field_name == EnumType.DONE` semantics should be + // equivalent to `Operation.done == true`. If the annotated field is of type + // boolean, then it should follow the same semantics as Operation.done. + // Otherwise, a non-empty value should be treated as `Operation.done == true`. + STATUS = 2; + + // A field in an API-specific (custom) Operation object which carries the same + // meaning as google.longrunning.Operation.error.code. + ERROR_CODE = 3; + + // A field in an API-specific (custom) Operation object which carries the same + // meaning as google.longrunning.Operation.error.message. + ERROR_MESSAGE = 4; +} \ No newline at end of file diff --git a/.venv/lib/python3.11/site-packages/google/cloud/extended_operations_pb2.py b/.venv/lib/python3.11/site-packages/google/cloud/extended_operations_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac8be1b2b548ee567343ef554b9fe00f839920b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/cloud/extended_operations_pb2.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: google/cloud/extended_operations.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b"\n&google/cloud/extended_operations.proto\x12\x0cgoogle.cloud\x1a google/protobuf/descriptor.proto*b\n\x18OperationResponseMapping\x12\r\n\tUNDEFINED\x10\x00\x12\x08\n\x04NAME\x10\x01\x12\n\n\x06STATUS\x10\x02\x12\x0e\n\nERROR_CODE\x10\x03\x12\x11\n\rERROR_MESSAGE\x10\x04:_\n\x0foperation_field\x12\x1d.google.protobuf.FieldOptions\x18\xfd\x08 \x01(\x0e\x32&.google.cloud.OperationResponseMapping:?\n\x17operation_request_field\x12\x1d.google.protobuf.FieldOptions\x18\xfe\x08 \x01(\t:@\n\x18operation_response_field\x12\x1d.google.protobuf.FieldOptions\x18\xff\x08 \x01(\t::\n\x11operation_service\x12\x1e.google.protobuf.MethodOptions\x18\xe1\t \x01(\t:A\n\x18operation_polling_method\x12\x1e.google.protobuf.MethodOptions\x18\xe2\t \x01(\x08\x42y\n\x10\x63om.google.cloudB\x17\x45xtendedOperationsProtoP\x01ZCgoogle.golang.org/genproto/googleapis/cloud/extendedops;extendedops\xa2\x02\x04GAPIb\x06proto3" +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "google.cloud.extended_operations_pb2", _globals +) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"\n\020com.google.cloudB\027ExtendedOperationsProtoP\001ZCgoogle.golang.org/genproto/googleapis/cloud/extendedops;extendedops\242\002\004GAPI" + _globals["_OPERATIONRESPONSEMAPPING"]._serialized_start = 90 + _globals["_OPERATIONRESPONSEMAPPING"]._serialized_end = 188 +# @@protoc_insertion_point(module_scope) diff --git a/.venv/lib/python3.11/site-packages/google/cloud/location/__pycache__/locations_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/cloud/location/__pycache__/locations_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9419312b6ab60d75310defe70abeed6b602b3437 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/cloud/location/__pycache__/locations_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/cloud/location/locations.proto b/.venv/lib/python3.11/site-packages/google/cloud/location/locations.proto new file mode 100644 index 0000000000000000000000000000000000000000..12de2aeb477b4c969588ffdcab5f58ab5b2269b0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/cloud/location/locations.proto @@ -0,0 +1,108 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.cloud.location; + +import "google/api/annotations.proto"; +import "google/protobuf/any.proto"; +import "google/api/client.proto"; + +option cc_enable_arenas = true; +option go_package = "google.golang.org/genproto/googleapis/cloud/location;location"; +option java_multiple_files = true; +option java_outer_classname = "LocationsProto"; +option java_package = "com.google.cloud.location"; + +// An abstract interface that provides location-related information for +// a service. Service-specific metadata is provided through the +// [Location.metadata][google.cloud.location.Location.metadata] field. +service Locations { + option (google.api.default_host) = "cloud.googleapis.com"; + option (google.api.oauth_scopes) = "https://www.googleapis.com/auth/cloud-platform"; + + // Lists information about the supported locations for this service. + rpc ListLocations(ListLocationsRequest) returns (ListLocationsResponse) { + option (google.api.http) = { + get: "/v1/{name=locations}" + additional_bindings { + get: "/v1/{name=projects/*}/locations" + } + }; + } + + // Gets information about a location. + rpc GetLocation(GetLocationRequest) returns (Location) { + option (google.api.http) = { + get: "/v1/{name=locations/*}" + additional_bindings { + get: "/v1/{name=projects/*/locations/*}" + } + }; + } +} + +// The request message for [Locations.ListLocations][google.cloud.location.Locations.ListLocations]. +message ListLocationsRequest { + // The resource that owns the locations collection, if applicable. + string name = 1; + + // The standard list filter. + string filter = 2; + + // The standard list page size. + int32 page_size = 3; + + // The standard list page token. + string page_token = 4; +} + +// The response message for [Locations.ListLocations][google.cloud.location.Locations.ListLocations]. +message ListLocationsResponse { + // A list of locations that matches the specified filter in the request. + repeated Location locations = 1; + + // The standard List next-page token. + string next_page_token = 2; +} + +// The request message for [Locations.GetLocation][google.cloud.location.Locations.GetLocation]. +message GetLocationRequest { + // Resource name for the location. + string name = 1; +} + +// A resource that represents Google Cloud Platform location. +message Location { + // Resource name for the location, which may vary between implementations. + // For example: `"projects/example-project/locations/us-east1"` + string name = 1; + + // The canonical id for this location. For example: `"us-east1"`. + string location_id = 4; + + // The friendly name for this location, typically a nearby city name. + // For example, "Tokyo". + string display_name = 5; + + // Cross-service attributes for the location. For example + // + // {"cloud.googleapis.com/region": "us-east1"} + map labels = 2; + + // Service-specific metadata. For example the available capacity at the given + // location. + google.protobuf.Any metadata = 3; +} diff --git a/.venv/lib/python3.11/site-packages/google/cloud/location/locations_pb2.py b/.venv/lib/python3.11/site-packages/google/cloud/location/locations_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..17bc85892b2b7bf2ece6b261823831eb7caa77bb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/cloud/location/locations_pb2.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: google/cloud/location/locations.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.api import annotations_pb2 as google_dot_api_dot_annotations__pb2 +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 +from google.api import client_pb2 as google_dot_api_dot_client__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n%google/cloud/location/locations.proto\x12\x15google.cloud.location\x1a\x1cgoogle/api/annotations.proto\x1a\x19google/protobuf/any.proto\x1a\x17google/api/client.proto"[\n\x14ListLocationsRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06\x66ilter\x18\x02 \x01(\t\x12\x11\n\tpage_size\x18\x03 \x01(\x05\x12\x12\n\npage_token\x18\x04 \x01(\t"d\n\x15ListLocationsResponse\x12\x32\n\tlocations\x18\x01 \x03(\x0b\x32\x1f.google.cloud.location.Location\x12\x17\n\x0fnext_page_token\x18\x02 \x01(\t""\n\x12GetLocationRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"\xd7\x01\n\x08Location\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0blocation_id\x18\x04 \x01(\t\x12\x14\n\x0c\x64isplay_name\x18\x05 \x01(\t\x12;\n\x06labels\x18\x02 \x03(\x0b\x32+.google.cloud.location.Location.LabelsEntry\x12&\n\x08metadata\x18\x03 \x01(\x0b\x32\x14.google.protobuf.Any\x1a-\n\x0bLabelsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x32\xa4\x03\n\tLocations\x12\xab\x01\n\rListLocations\x12+.google.cloud.location.ListLocationsRequest\x1a,.google.cloud.location.ListLocationsResponse"?\x82\xd3\xe4\x93\x02\x39\x12\x14/v1/{name=locations}Z!\x12\x1f/v1/{name=projects/*}/locations\x12\x9e\x01\n\x0bGetLocation\x12).google.cloud.location.GetLocationRequest\x1a\x1f.google.cloud.location.Location"C\x82\xd3\xe4\x93\x02=\x12\x16/v1/{name=locations/*}Z#\x12!/v1/{name=projects/*/locations/*}\x1aH\xca\x41\x14\x63loud.googleapis.com\xd2\x41.https://www.googleapis.com/auth/cloud-platformBo\n\x19\x63om.google.cloud.locationB\x0eLocationsProtoP\x01Z=google.golang.org/genproto/googleapis/cloud/location;location\xf8\x01\x01\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "google.cloud.location.locations_pb2", _globals +) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"\n\031com.google.cloud.locationB\016LocationsProtoP\001Z=google.golang.org/genproto/googleapis/cloud/location;location\370\001\001" + _LOCATION_LABELSENTRY._options = None + _LOCATION_LABELSENTRY._serialized_options = b"8\001" + _LOCATIONS._options = None + _LOCATIONS._serialized_options = b"\312A\024cloud.googleapis.com\322A.https://www.googleapis.com/auth/cloud-platform" + _LOCATIONS.methods_by_name["ListLocations"]._options = None + _LOCATIONS.methods_by_name[ + "ListLocations" + ]._serialized_options = b"\202\323\344\223\0029\022\024/v1/{name=locations}Z!\022\037/v1/{name=projects/*}/locations" + _LOCATIONS.methods_by_name["GetLocation"]._options = None + _LOCATIONS.methods_by_name[ + "GetLocation" + ]._serialized_options = b"\202\323\344\223\002=\022\026/v1/{name=locations/*}Z#\022!/v1/{name=projects/*/locations/*}" + _globals["_LISTLOCATIONSREQUEST"]._serialized_start = 146 + _globals["_LISTLOCATIONSREQUEST"]._serialized_end = 237 + _globals["_LISTLOCATIONSRESPONSE"]._serialized_start = 239 + _globals["_LISTLOCATIONSRESPONSE"]._serialized_end = 339 + _globals["_GETLOCATIONREQUEST"]._serialized_start = 341 + _globals["_GETLOCATIONREQUEST"]._serialized_end = 375 + _globals["_LOCATION"]._serialized_start = 378 + _globals["_LOCATION"]._serialized_end = 593 + _globals["_LOCATION_LABELSENTRY"]._serialized_start = 548 + _globals["_LOCATION_LABELSENTRY"]._serialized_end = 593 + _globals["_LOCATIONS"]._serialized_start = 596 + _globals["_LOCATIONS"]._serialized_end = 1016 +# @@protoc_insertion_point(module_scope) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/__init__.py b/.venv/lib/python3.11/site-packages/google/generativeai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b143d7683822f4cee43a883637ad30a6f7bbd1a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/__init__.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Google AI Python SDK + +## Setup + +```posix-terminal +pip install google-generativeai +``` + +## GenerativeModel + +Use `genai.GenerativeModel` to access the API: + +``` +import google.generativeai as genai +import os + +genai.configure(api_key=os.environ['API_KEY']) + +model = genai.GenerativeModel(model_name='gemini-1.5-flash') +response = model.generate_content('Teach me about how an LLM works') + +print(response.text) +``` + +See the [python quickstart](https://ai.google.dev/tutorials/python_quickstart) for more details. +""" +from __future__ import annotations + +from google.generativeai import version + +from google.generativeai import caching +from google.generativeai import protos +from google.generativeai import types + +from google.generativeai.client import configure + +from google.generativeai.embedding import embed_content +from google.generativeai.embedding import embed_content_async + +from google.generativeai.files import upload_file +from google.generativeai.files import get_file +from google.generativeai.files import list_files +from google.generativeai.files import delete_file + +from google.generativeai.generative_models import GenerativeModel +from google.generativeai.generative_models import ChatSession + +from google.generativeai.models import list_models +from google.generativeai.models import list_tuned_models + +from google.generativeai.models import get_model +from google.generativeai.models import get_base_model +from google.generativeai.models import get_tuned_model + +from google.generativeai.models import create_tuned_model +from google.generativeai.models import update_tuned_model +from google.generativeai.models import delete_tuned_model + +from google.generativeai.operations import list_operations +from google.generativeai.operations import get_operation + +from google.generativeai.types import GenerationConfig + +__version__ = version.__version__ + +del embedding +del files +del generative_models +del models +del client +del operations +del version diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/answer.py b/.venv/lib/python3.11/site-packages/google/generativeai/answer.py new file mode 100644 index 0000000000000000000000000000000000000000..83bf5f679b00ef2edd1e112f8b5b2ca6bdc7bac8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/answer.py @@ -0,0 +1,363 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import dataclasses +from collections.abc import Iterable +import itertools +from typing import Any, Iterable, Union, Mapping, Optional +from typing_extensions import TypedDict + +import google.ai.generativelanguage as glm +from google.generativeai import protos + +from google.generativeai.client import ( + get_default_generative_client, + get_default_generative_async_client, +) +from google.generativeai.types import model_types +from google.generativeai.types import helper_types +from google.generativeai.types import safety_types +from google.generativeai.types import content_types +from google.generativeai.types import retriever_types +from google.generativeai.types.retriever_types import MetadataFilter + +DEFAULT_ANSWER_MODEL = "models/aqa" + +AnswerStyle = protos.GenerateAnswerRequest.AnswerStyle + +AnswerStyleOptions = Union[int, str, AnswerStyle] + +_ANSWER_STYLES: dict[AnswerStyleOptions, AnswerStyle] = { + AnswerStyle.ANSWER_STYLE_UNSPECIFIED: AnswerStyle.ANSWER_STYLE_UNSPECIFIED, + 0: AnswerStyle.ANSWER_STYLE_UNSPECIFIED, + "answer_style_unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED, + "unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED, + AnswerStyle.ABSTRACTIVE: AnswerStyle.ABSTRACTIVE, + 1: AnswerStyle.ABSTRACTIVE, + "answer_style_abstractive": AnswerStyle.ABSTRACTIVE, + "abstractive": AnswerStyle.ABSTRACTIVE, + AnswerStyle.EXTRACTIVE: AnswerStyle.EXTRACTIVE, + 2: AnswerStyle.EXTRACTIVE, + "answer_style_extractive": AnswerStyle.EXTRACTIVE, + "extractive": AnswerStyle.EXTRACTIVE, + AnswerStyle.VERBOSE: AnswerStyle.VERBOSE, + 3: AnswerStyle.VERBOSE, + "answer_style_verbose": AnswerStyle.VERBOSE, + "verbose": AnswerStyle.VERBOSE, +} + + +def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle: + if isinstance(x, str): + x = x.lower() + return _ANSWER_STYLES[x] + + +GroundingPassageOptions = ( + Union[ + protos.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType + ], +) + +GroundingPassagesOptions = Union[ + protos.GroundingPassages, + Iterable[GroundingPassageOptions], + Mapping[str, content_types.ContentType], +] + + +def _make_grounding_passages(source: GroundingPassagesOptions) -> protos.GroundingPassages: + """ + Converts the `source` into a `protos.GroundingPassage`. A `GroundingPassages` contains a list of + `protos.GroundingPassage` objects, which each contain a `protos.Content` and a string `id`. + + Args: + source: `Content` or a `GroundingPassagesOptions` that will be converted to protos.GroundingPassages. + + Return: + `protos.GroundingPassages` to be passed into `protos.GenerateAnswer`. + """ + if isinstance(source, protos.GroundingPassages): + return source + + if not isinstance(source, Iterable): + raise TypeError( + f"Invalid input: The 'source' argument must be an instance of 'GroundingPassagesOptions'. Received a '{type(source).__name__}' object instead." + ) + + passages = [] + if isinstance(source, Mapping): + source = source.items() + + for n, data in enumerate(source): + if isinstance(data, protos.GroundingPassage): + passages.append(data) + elif isinstance(data, tuple): + id, content = data # tuple must have exactly 2 items. + passages.append({"id": id, "content": content_types.to_content(content)}) + else: + passages.append({"id": str(n), "content": content_types.to_content(data)}) + + return protos.GroundingPassages(passages=passages) + + +SourceNameType = Union[ + str, retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document +] + + +class SemanticRetrieverConfigDict(TypedDict): + source: SourceNameType + query: content_types.ContentsType + metadata_filter: Optional[Iterable[MetadataFilter]] + max_chunks_count: Optional[int] + minimum_relevance_score: Optional[float] + + +SemanticRetrieverConfigOptions = Union[ + SourceNameType, + SemanticRetrieverConfigDict, + protos.SemanticRetrieverConfig, +] + + +def _maybe_get_source_name(source) -> str | None: + if isinstance(source, str): + return source + elif isinstance( + source, (retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document) + ): + return source.name + else: + return None + + +def _make_semantic_retriever_config( + source: SemanticRetrieverConfigOptions, + query: content_types.ContentsType, +) -> protos.SemanticRetrieverConfig: + if isinstance(source, protos.SemanticRetrieverConfig): + return source + + name = _maybe_get_source_name(source) + if name is not None: + source = {"source": name} + elif isinstance(source, dict): + source["source"] = _maybe_get_source_name(source["source"]) + else: + raise TypeError( + f"Invalid input: Failed to create a 'protos.SemanticRetrieverConfig' from the provided source. " + f"Received type: {type(source).__name__}, " + f"Received value: {source}" + ) + + if source["query"] is None: + source["query"] = query + elif isinstance(source["query"], str): + source["query"] = content_types.to_content(source["query"]) + + return protos.SemanticRetrieverConfig(source) + + +def _make_generate_answer_request( + *, + model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL, + contents: content_types.ContentsType, + inline_passages: GroundingPassagesOptions | None = None, + semantic_retriever: SemanticRetrieverConfigOptions | None = None, + answer_style: AnswerStyle | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + temperature: float | None = None, +) -> protos.GenerateAnswerRequest: + """ + constructs a protos.GenerateAnswerRequest object by organizing the input parameters for the API call to generate a grounded answer from the model. + + Args: + model: Name of the model used to generate the grounded response. + contents: Content of the current conversation with the model. For single-turn query, this is a + single question to answer. For multi-turn queries, this is a repeated field that contains + conversation history and the last `Content` in the list containing the question. + inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retriever`, + one must be set, but not both. + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with + `inline_passages`, one must be set, but not both. + answer_style: Style for grounded answers. + safety_settings: Safety settings for generated output. + temperature: The temperature for randomness in the output. + + Returns: + Call for protos.GenerateAnswerRequest(). + """ + model = model_types.make_model_name(model) + + contents = content_types.to_contents(contents) + + if safety_settings: + safety_settings = safety_types.normalize_safety_settings(safety_settings) + + if inline_passages is not None and semantic_retriever is not None: + raise ValueError( + f"Invalid configuration: Please set either 'inline_passages' or 'semantic_retriever_config', but not both. " + f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}." + ) + elif inline_passages is not None: + inline_passages = _make_grounding_passages(inline_passages) + elif semantic_retriever is not None: + semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1]) + else: + raise TypeError( + f"Invalid configuration: Either 'inline_passages' or 'semantic_retriever_config' must be provided, but currently both are 'None'. " + f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}." + ) + + if answer_style: + answer_style = to_answer_style(answer_style) + + return protos.GenerateAnswerRequest( + model=model, + contents=contents, + inline_passages=inline_passages, + semantic_retriever=semantic_retriever, + safety_settings=safety_settings, + temperature=temperature, + answer_style=answer_style, + ) + + +def generate_answer( + *, + model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL, + contents: content_types.ContentsType, + inline_passages: GroundingPassagesOptions | None = None, + semantic_retriever: SemanticRetrieverConfigOptions | None = None, + answer_style: AnswerStyle | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + temperature: float | None = None, + client: glm.GenerativeServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +): + """Calls the GenerateAnswer API and returns a `types.Answer` containing the response. + + You can pass a literal list of text chunks: + + >>> from google.generativeai import answer + >>> answer.generate_answer( + ... content=question, + ... inline_passages=splitter.split(document) + ... ) + + Or pass a reference to a retreiver Document or Corpus: + + >>> from google.generativeai import answer + >>> from google.generativeai import retriever + >>> my_corpus = retriever.get_corpus('my_corpus') + >>> genai.generate_answer( + ... content=question, + ... semantic_retriever=my_corpus + ... ) + + + Args: + model: Which model to call, as a string or a `types.Model`. + contents: The question to be answered by the model, grounded in the + provided source. + inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retriever`, + one must be set, but not both. + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with + `inline_passages`, one must be set, but not both. + answer_style: Style in which the grounded answer should be returned. + safety_settings: Safety settings for generated output. Defaults to None. + temperature: Controls the randomness of the output. + client: If you're not relying on a default client, you pass a `glm.GenerativeServiceClient` instead. + request_options: Options for the request. + + Returns: + A `types.Answer` containing the model's text answer response. + """ + if request_options is None: + request_options = {} + + if client is None: + client = get_default_generative_client() + + request = _make_generate_answer_request( + model=model, + contents=contents, + inline_passages=inline_passages, + semantic_retriever=semantic_retriever, + safety_settings=safety_settings, + temperature=temperature, + answer_style=answer_style, + ) + + response = client.generate_answer(request, **request_options) + + return response + + +async def generate_answer_async( + *, + model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL, + contents: content_types.ContentsType, + inline_passages: GroundingPassagesOptions | None = None, + semantic_retriever: SemanticRetrieverConfigOptions | None = None, + answer_style: AnswerStyle | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + temperature: float | None = None, + client: glm.GenerativeServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +): + """ + Calls the API and returns a `types.Answer` containing the answer. + + Args: + model: Which model to call, as a string or a `types.Model`. + contents: The question to be answered by the model, grounded in the + provided source. + inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retriever`, + one must be set, but not both. + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with + `inline_passages`, one must be set, but not both. + answer_style: Style in which the grounded answer should be returned. + safety_settings: Safety settings for generated output. Defaults to None. + temperature: Controls the randomness of the output. + client: If you're not relying on a default client, you pass a `glm.GenerativeServiceClient` instead. + + Returns: + A `types.Answer` containing the model's text answer response. + """ + if request_options is None: + request_options = {} + + if client is None: + client = get_default_generative_async_client() + + request = _make_generate_answer_request( + model=model, + contents=contents, + inline_passages=inline_passages, + semantic_retriever=semantic_retriever, + safety_settings=safety_settings, + temperature=temperature, + answer_style=answer_style, + ) + + response = await client.generate_answer(request, **request_options) + + return response diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__init__.py b/.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c64029d4b969b5b0c8a42f31933c9010ab2bb1f9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/_audio_models.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/_audio_models.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df36d2053510c5bd04b0cdf7d042a9bbfc0b1831 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/_audio_models.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/audio_models/_audio_models.py b/.venv/lib/python3.11/site-packages/google/generativeai/audio_models/_audio_models.py new file mode 100644 index 0000000000000000000000000000000000000000..c9e78f163a98f7aae51fd4cc551fa572b2dbad15 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/audio_models/_audio_models.py @@ -0,0 +1,3 @@ +class SpeechGenerationModel: + def __init__(self, model_name): + self.model_name = model_name diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/caching.py b/.venv/lib/python3.11/site-packages/google/generativeai/caching.py new file mode 100644 index 0000000000000000000000000000000000000000..9acad4726234d443de67427821fee82f632b5099 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/caching.py @@ -0,0 +1,314 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import datetime +import textwrap +from typing import Iterable, Optional + +from google.generativeai import protos +from google.generativeai.types import caching_types +from google.generativeai.types import content_types +from google.generativeai.client import get_default_cache_client + +from google.protobuf import field_mask_pb2 + +_USER_ROLE = "user" +_MODEL_ROLE = "model" + + +class CachedContent: + """Cached content resource.""" + + def __init__(self, name): + """Fetches a `CachedContent` resource. + + Identical to `CachedContent.get`. + + Args: + name: The resource name referring to the cached content. + """ + client = get_default_cache_client() + + if "cachedContents/" not in name: + name = "cachedContents/" + name + + request = protos.GetCachedContentRequest(name=name) + response = client.get_cached_content(request) + self._proto = response + + @property + def name(self) -> str: + return self._proto.name + + @property + def model(self) -> str: + return self._proto.model + + @property + def display_name(self) -> str: + return self._proto.display_name + + @property + def usage_metadata(self) -> protos.CachedContent.UsageMetadata: + return self._proto.usage_metadata + + @property + def create_time(self) -> datetime.datetime: + return self._proto.create_time + + @property + def update_time(self) -> datetime.datetime: + return self._proto.update_time + + @property + def expire_time(self) -> datetime.datetime: + return self._proto.expire_time + + def __str__(self): + return textwrap.dedent( + f"""\ + CachedContent( + name='{self.name}', + model='{self.model}', + display_name='{self.display_name}', + usage_metadata={'{'} + 'total_token_count': {self.usage_metadata.total_token_count}, + {'}'}, + create_time={self.create_time}, + update_time={self.update_time}, + expire_time={self.expire_time} + )""" + ) + + __repr__ = __str__ + + @classmethod + def _from_obj(cls, obj: CachedContent | protos.CachedContent | dict) -> CachedContent: + """Creates an instance of CachedContent form an object, without calling `get`.""" + self = cls.__new__(cls) + self._proto = protos.CachedContent() + self._update(obj) + return self + + def _update(self, updates): + """Updates this instance inplace, does not call the API's `update` method""" + if isinstance(updates, CachedContent): + updates = updates._proto + + if not isinstance(updates, dict): + updates = type(updates).to_dict(updates, including_default_value_fields=False) + + for key, value in updates.items(): + setattr(self._proto, key, value) + + @staticmethod + def _prepare_create_request( + model: str, + *, + display_name: str | None = None, + system_instruction: Optional[content_types.ContentType] = None, + contents: Optional[content_types.ContentsType] = None, + tools: Optional[content_types.FunctionLibraryType] = None, + tool_config: Optional[content_types.ToolConfigType] = None, + ttl: Optional[caching_types.TTLTypes] = None, + expire_time: Optional[caching_types.ExpireTimeTypes] = None, + ) -> protos.CreateCachedContentRequest: + """Prepares a CreateCachedContentRequest.""" + if ttl and expire_time: + raise ValueError( + "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both." + ) + + if "/" not in model: + model = "models/" + model + + if display_name and len(display_name) > 128: + raise ValueError("`display_name` must be no more than 128 unicode characters.") + + if system_instruction: + system_instruction = content_types.to_content(system_instruction) + + tools_lib = content_types.to_function_library(tools) + if tools_lib: + tools_lib = tools_lib.to_proto() + + if tool_config: + tool_config = content_types.to_tool_config(tool_config) + + if contents: + contents = content_types.to_contents(contents) + if not contents[-1].role: + contents[-1].role = _USER_ROLE + + ttl = caching_types.to_optional_ttl(ttl) + expire_time = caching_types.to_optional_expire_time(expire_time) + + cached_content = protos.CachedContent( + model=model, + display_name=display_name, + system_instruction=system_instruction, + contents=contents, + tools=tools_lib, + tool_config=tool_config, + ttl=ttl, + expire_time=expire_time, + ) + + return protos.CreateCachedContentRequest(cached_content=cached_content) + + @classmethod + def create( + cls, + model: str, + *, + display_name: str | None = None, + system_instruction: Optional[content_types.ContentType] = None, + contents: Optional[content_types.ContentsType] = None, + tools: Optional[content_types.FunctionLibraryType] = None, + tool_config: Optional[content_types.ToolConfigType] = None, + ttl: Optional[caching_types.TTLTypes] = None, + expire_time: Optional[caching_types.ExpireTimeTypes] = None, + ) -> CachedContent: + """Creates `CachedContent` resource. + + Args: + model: The name of the `model` to use for cached content creation. + Any `CachedContent` resource can be only used with the + `model` it was created for. + display_name: The user-generated meaningful display name + of the cached content. `display_name` must be no + more than 128 unicode characters. + system_instruction: Developer set system instruction. + contents: Contents to cache. + tools: A list of `Tools` the model may use to generate response. + tool_config: Config to apply to all tools. + ttl: TTL for cached resource (in seconds). Defaults to 1 hour. + `ttl` and `expire_time` are exclusive arguments. + expire_time: Expiration time for cached resource. + `ttl` and `expire_time` are exclusive arguments. + + Returns: + `CachedContent` resource with specified name. + """ + client = get_default_cache_client() + + request = cls._prepare_create_request( + model=model, + display_name=display_name, + system_instruction=system_instruction, + contents=contents, + tools=tools, + tool_config=tool_config, + ttl=ttl, + expire_time=expire_time, + ) + + response = client.create_cached_content(request) + result = CachedContent._from_obj(response) + return result + + @classmethod + def get(cls, name: str) -> CachedContent: + """Fetches required `CachedContent` resource. + + Args: + name: The resource name referring to the cached content. + + Returns: + `CachedContent` resource with specified `name`. + """ + client = get_default_cache_client() + + if "cachedContents/" not in name: + name = "cachedContents/" + name + + request = protos.GetCachedContentRequest(name=name) + response = client.get_cached_content(request) + result = CachedContent._from_obj(response) + return result + + @classmethod + def list(cls, page_size: Optional[int] = 1) -> Iterable[CachedContent]: + """Lists `CachedContent` objects associated with the project. + + Args: + page_size: The maximum number of permissions to return (per page). + The service may return fewer `CachedContent` objects. + + Returns: + A paginated list of `CachedContent` objects. + """ + client = get_default_cache_client() + + request = protos.ListCachedContentsRequest(page_size=page_size) + for cached_content in client.list_cached_contents(request): + cached_content = CachedContent._from_obj(cached_content) + yield cached_content + + def delete(self) -> None: + """Deletes `CachedContent` resource.""" + client = get_default_cache_client() + + request = protos.DeleteCachedContentRequest(name=self.name) + client.delete_cached_content(request) + return + + def update( + self, + *, + ttl: Optional[caching_types.TTLTypes] = None, + expire_time: Optional[caching_types.ExpireTimeTypes] = None, + ) -> None: + """Updates requested `CachedContent` resource. + + Args: + ttl: TTL for cached resource (in seconds). Defaults to 1 hour. + `ttl` and `expire_time` are exclusive arguments. + expire_time: Expiration time for cached resource. + `ttl` and `expire_time` are exclusive arguments. + """ + client = get_default_cache_client() + + if ttl and expire_time: + raise ValueError( + "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both." + ) + + ttl = caching_types.to_optional_ttl(ttl) + expire_time = caching_types.to_optional_expire_time(expire_time) + + updates = protos.CachedContent( + name=self.name, + ttl=ttl, + expire_time=expire_time, + ) + + field_mask = field_mask_pb2.FieldMask() + + if ttl: + field_mask.paths.append("ttl") + elif expire_time: + field_mask.paths.append("expire_time") + else: + raise ValueError( + f"Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`." + ) + + request = protos.UpdateCachedContentRequest(cached_content=updates, update_mask=field_mask) + updated_cc = client.update_cached_content(request) + self._update(updated_cc) + + return diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/client.py b/.venv/lib/python3.11/site-packages/google/generativeai/client.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c5c8c5b0909d562416c9b49ebcc8f00aa854b1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/client.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +import os +import contextlib +import inspect +import dataclasses +import pathlib +import threading +from typing import Any, cast +from collections.abc import Sequence +import httplib2 +from io import IOBase + +import google.ai.generativelanguage as glm +import google.generativeai.protos as protos + +from google.auth import credentials as ga_credentials +from google.auth import exceptions as ga_exceptions +from google import auth +from google.api_core import client_options as client_options_lib +from google.api_core import gapic_v1 +from google.api_core import operations_v1 + +import googleapiclient.http +import googleapiclient.discovery + +try: + from google.generativeai import version + + __version__ = version.__version__ +except ImportError: + __version__ = "0.0.0" + +USER_AGENT = "genai-py" + +#### Caution! #### +# - It would make sense for the discovery URL to respect the client_options.endpoint setting. +# - That would make testing Files on the staging server possible. +# - We tried fixing this once, but broke colab in the process because their endpoint didn't forward the discovery +# requests. https://github.com/google-gemini/generative-ai-python/pull/333 +# - Kaggle would have a similar problem (b/362278209). +# - I think their proxy would forward the discovery traffic. +# - But they don't need to intercept the files-service at all, and uploads of large files could overload them. +# - Do the scotty uploads go to the same domain? +# - If you do route the discovery call to kaggle, be sure to attach the default_metadata (they need it). +# - One solution to all this would be if configure could take overrides per service. +# - set client_options.endpoint, but use a different endpoint for file service? It's not clear how best to do that +# through the file service. +################## +GENAI_API_DISCOVERY_URL = "https://generativelanguage.googleapis.com/$discovery/rest" + + +@contextlib.contextmanager +def patch_colab_gce_credentials(): + get_gce = auth._default._get_gce_credentials + if "COLAB_RELEASE_TAG" in os.environ: + auth._default._get_gce_credentials = lambda *args, **kwargs: (None, None) + + try: + yield + finally: + auth._default._get_gce_credentials = get_gce + + +class FileServiceClient(glm.FileServiceClient): + def __init__(self, *args, **kwargs): + self._discovery_api = None + self._local = threading.local() + super().__init__(*args, **kwargs) + + def _setup_discovery_api(self, metadata: dict | Sequence[tuple[str, str]] = ()): + api_key = self._client_options.api_key + if api_key is None: + raise ValueError( + "Invalid operation: Uploading to the File API requires an API key. Please provide a valid API key." + ) + + request = googleapiclient.http.HttpRequest( + http=httplib2.Http(), + postproc=lambda resp, content: (resp, content), + uri=f"{GENAI_API_DISCOVERY_URL}?version=v1beta&key={api_key}", + headers=dict(metadata), + ) + response, content = request.execute() + request.http.close() + + discovery_doc = content.decode("utf-8") + self._local.discovery_api = googleapiclient.discovery.build_from_document( + discovery_doc, developerKey=api_key + ) + + def create_file( + self, + path: str | pathlib.Path | os.PathLike | IOBase, + *, + mime_type: str | None = None, + name: str | None = None, + display_name: str | None = None, + resumable: bool = True, + metadata: Sequence[tuple[str, str]] = (), + ) -> protos.File: + if self._discovery_api is None: + self._setup_discovery_api(metadata) + + file = {} + if name is not None: + file["name"] = name + if display_name is not None: + file["displayName"] = display_name + + if isinstance(path, IOBase): + media = googleapiclient.http.MediaIoBaseUpload( + fd=path, mimetype=mime_type, resumable=resumable + ) + else: + media = googleapiclient.http.MediaFileUpload( + filename=path, mimetype=mime_type, resumable=resumable + ) + + request = self._local.discovery_api.media().upload(body={"file": file}, media_body=media) + for key, value in metadata: + request.headers[key] = value + result = request.execute() + + return self.get_file({"name": result["file"]["name"]}) + + +class FileServiceAsyncClient(glm.FileServiceAsyncClient): + async def create_file(self, *args, **kwargs): + raise NotImplementedError( + "The `create_file` method is currently not supported for the asynchronous client." + ) + + +@dataclasses.dataclass +class _ClientManager: + client_config: dict[str, Any] = dataclasses.field(default_factory=dict) + default_metadata: Sequence[tuple[str, str]] = () + clients: dict[str, Any] = dataclasses.field(default_factory=dict) + + def configure( + self, + *, + api_key: str | None = None, + credentials: ga_credentials.Credentials | dict | None = None, + # The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'. + # See _transport_registry in the google.ai.generativelanguage package. + # Since the transport classes align with the client classes it wouldn't make + # sense to accept a `Transport` object here even though the client classes can. + # We could accept a dict since all the `Transport` classes take the same args, + # but that seems rare. Users that need it can just switch to the low level API. + transport: str | None = None, + client_options: client_options_lib.ClientOptions | dict[str, Any] | None = None, + client_info: gapic_v1.client_info.ClientInfo | None = None, + default_metadata: Sequence[tuple[str, str]] = (), + ) -> None: + """Initializes default client configurations using specified parameters or environment variables. + + If no API key has been provided (either directly, or on `client_options`) and the + `GEMINI_API_KEY` environment variable is set, it will be used as the API key. If not, + if the `GOOGLE_API_KEY` environement variable is set, it will be used as the API key. + + Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in + `google.ai.generativelanguage` for details on the other arguments. + + Args: + transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`]. + api_key: The API-Key to use when creating the default clients (each service uses + a separate client). This is a shortcut for `client_options={"api_key": api_key}`. + If omitted, and the `GEMINI_API_KEY` or the `GOOGLE_API_KEY` environment variable + are set, they will be used in this order of priority. + default_metadata: Default (key, value) metadata pairs to send with every request. + when using `transport="rest"` these are sent as HTTP headers. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + client_options = cast(client_options_lib.ClientOptions, client_options) + had_api_key_value = getattr(client_options, "api_key", None) + + if had_api_key_value: + if api_key is not None: + raise ValueError( + "Invalid configuration: Please set either `api_key` or `client_options['api_key']`, but not both." + ) + else: + if api_key is None: + # If no key is provided explicitly, attempt to load one from the + # environment. + api_key = os.getenv("GEMINI_API_KEY") + + if api_key is None: + # If the GEMINI_API_KEY doesn't exist, attempt to load the + # GOOGLE_API_KEY from the environment. + api_key = os.getenv("GOOGLE_API_KEY") + + client_options.api_key = api_key + + user_agent = f"{USER_AGENT}/{__version__}" + if client_info: + # Be respectful of any existing agent setting. + if client_info.user_agent: + client_info.user_agent += f" {user_agent}" + else: + client_info.user_agent = user_agent + else: + client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent) + + client_config = { + "credentials": credentials, + "transport": transport, + "client_options": client_options, + "client_info": client_info, + } + + client_config = {key: value for key, value in client_config.items() if value is not None} + + self.client_config = client_config + self.default_metadata = default_metadata + + self.clients = {} + + def make_client(self, name): + if name == "file": + cls = FileServiceClient + elif name == "file_async": + cls = FileServiceAsyncClient + elif name.endswith("_async"): + name = name.split("_")[0] + cls = getattr(glm, name.title() + "ServiceAsyncClient") + else: + cls = getattr(glm, name.title() + "ServiceClient") + + # Attempt to configure using defaults. + if not self.client_config: + configure() + + try: + with patch_colab_gce_credentials(): + client = cls(**self.client_config) + except ga_exceptions.DefaultCredentialsError as e: + e.args = ( + "\n No API_KEY or ADC found. Please either:\n" + " - Set the `GOOGLE_API_KEY` environment variable.\n" + " - Manually pass the key with `genai.configure(api_key=my_api_key)`.\n" + " - Or set up Application Default Credentials, see https://ai.google.dev/gemini-api/docs/oauth for more information.", + ) + raise e + + if not self.default_metadata: + return client + + def keep(name, f): + if name.startswith("_"): + return False + + if not callable(f): + return False + + if "metadata" not in inspect.signature(f).parameters.keys(): + return False + + return True + + def add_default_metadata_wrapper(f): + def call(*args, metadata=(), **kwargs): + metadata = list(metadata) + list(self.default_metadata) + return f(*args, **kwargs, metadata=metadata) + + return call + + for name, value in inspect.getmembers(cls): + if not keep(name, value): + continue + f = getattr(client, name) + f = add_default_metadata_wrapper(f) + setattr(client, name, f) + + return client + + def get_default_client(self, name): + name = name.lower() + if name == "operations": + return self.get_default_operations_client() + + client = self.clients.get(name) + if client is None: + client = self.make_client(name) + self.clients[name] = client + return client + + def get_default_operations_client(self) -> operations_v1.OperationsClient: + client = self.clients.get("operations", None) + if client is None: + model_client = self.get_default_client("Model") + client = model_client._transport.operations_client + self.clients["operations"] = client + return client + + +def configure( + *, + api_key: str | None = None, + credentials: ga_credentials.Credentials | dict | None = None, + # The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'. + # Since the transport classes align with the client classes it wouldn't make + # sense to accept a `Transport` object here even though the client classes can. + # We could accept a dict since all the `Transport` classes take the same args, + # but that seems rare. Users that need it can just switch to the low level API. + transport: str | None = None, + client_options: client_options_lib.ClientOptions | dict | None = None, + client_info: gapic_v1.client_info.ClientInfo | None = None, + default_metadata: Sequence[tuple[str, str]] = (), +): + """Captures default client configuration. + + If no API key has been provided (either directly, or on `client_options`) and the + `GOOGLE_API_KEY` environment variable is set, it will be used as the API key. + + Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in + `google.ai.generativelanguage` for details on the other arguments. + + Args: + transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`]. + api_key: The API-Key to use when creating the default clients (each service uses + a separate client). This is a shortcut for `client_options={"api_key": api_key}`. + If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be + used. + default_metadata: Default (key, value) metadata pairs to send with every request. + when using `transport="rest"` these are sent as HTTP headers. + """ + return _client_manager.configure( + api_key=api_key, + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + default_metadata=default_metadata, + ) + + +_client_manager = _ClientManager() +_client_manager.configure() + + +def get_default_cache_client() -> glm.CacheServiceClient: + return _client_manager.get_default_client("cache") + + +def get_default_file_client() -> glm.FilesServiceClient: + return _client_manager.get_default_client("file") + + +def get_default_file_async_client() -> glm.FilesServiceAsyncClient: + return _client_manager.get_default_client("file_async") + + +def get_default_generative_client() -> glm.GenerativeServiceClient: + return _client_manager.get_default_client("generative") + + +def get_default_generative_async_client() -> glm.GenerativeServiceAsyncClient: + return _client_manager.get_default_client("generative_async") + + +def get_default_operations_client() -> operations_v1.OperationsClient: + return _client_manager.get_default_client("operations") + + +def get_default_model_client() -> glm.ModelServiceAsyncClient: + return _client_manager.get_default_client("model") + + +def get_default_retriever_client() -> glm.RetrieverClient: + return _client_manager.get_default_client("retriever") + + +def get_default_retriever_async_client() -> glm.RetrieverAsyncClient: + return _client_manager.get_default_client("retriever_async") + + +def get_default_permission_client() -> glm.PermissionServiceClient: + return _client_manager.get_default_client("permission") + + +def get_default_permission_async_client() -> glm.PermissionServiceAsyncClient: + return _client_manager.get_default_client("permission_async") diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/discuss.py b/.venv/lib/python3.11/site-packages/google/generativeai/discuss.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc342096ae9fa223514702c6c3e7ba8d558d252 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/discuss.py @@ -0,0 +1,590 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import dataclasses +import sys +import textwrap + +from typing import Any, Iterable, List, Optional, Union + +import google.ai.generativelanguage as glm + +from google.generativeai.client import get_default_discuss_client +from google.generativeai.client import get_default_discuss_async_client +from google.generativeai import string_utils +from google.generativeai.types import discuss_types +from google.generativeai.types import model_types +from google.generativeai.types import safety_types + + +def _make_message(content: discuss_types.MessageOptions) -> glm.Message: + """Creates a `glm.Message` object from the provided content.""" + if isinstance(content, glm.Message): + return content + if isinstance(content, str): + return glm.Message(content=content) + else: + return glm.Message(content) + + +def _make_messages( + messages: discuss_types.MessagesOptions, +) -> List[glm.Message]: + """ + Creates a list of `glm.Message` objects from the provided messages. + + This function takes a variety of message content inputs, such as strings, dictionaries, + or `glm.Message` objects, and creates a list of `glm.Message` objects. It ensures that + the authors of the messages alternate appropriately. If authors are not provided, + default authors are assigned based on their position in the list. + + Args: + messages: The messages to convert. + + Returns: + A list of `glm.Message` objects with alternating authors. + """ + if isinstance(messages, (str, dict, glm.Message)): + messages = [_make_message(messages)] + else: + messages = [_make_message(message) for message in messages] + + even_authors = set(msg.author for msg in messages[::2] if msg.author) + if not even_authors: + even_author = "0" + elif len(even_authors) == 1: + even_author = even_authors.pop() + else: + raise discuss_types.AuthorError("Authors are not strictly alternating") + + odd_authors = set(msg.author for msg in messages[1::2] if msg.author) + if not odd_authors: + odd_author = "1" + elif len(odd_authors) == 1: + odd_author = odd_authors.pop() + else: + raise discuss_types.AuthorError("Authors are not strictly alternating") + + if all(msg.author for msg in messages): + return messages + + authors = [even_author, odd_author] + for i, msg in enumerate(messages): + msg.author = authors[i % 2] + + return messages + + +def _make_example(item: discuss_types.ExampleOptions) -> glm.Example: + """Creates a `glm.Example` object from the provided item.""" + if isinstance(item, glm.Example): + return item + + if isinstance(item, dict): + item = item.copy() + item["input"] = _make_message(item["input"]) + item["output"] = _make_message(item["output"]) + return glm.Example(item) + + if isinstance(item, Iterable): + input, output = list(item) + return glm.Example(input=_make_message(input), output=_make_message(output)) + + # try anyway + return glm.Example(item) + + +def _make_examples_from_flat( + examples: List[discuss_types.MessageOptions], +) -> List[glm.Example]: + """ + Creates a list of `glm.Example` objects from a list of message options. + + This function takes a list of `discuss_types.MessageOptions` and pairs them into + `glm.Example` objects. The input examples must be in pairs to create valid examples. + + Args: + examples: The list of `discuss_types.MessageOptions`. + + Returns: + A list of `glm.Example objects` created by pairing up the provided messages. + + Raises: + ValueError: If the provided list of examples is not of even length. + """ + if len(examples) % 2 != 0: + raise ValueError( + textwrap.dedent( + f"""\ + You must pass `Primer` objects, pairs of messages, or an *even* number of messages, got: + {len(examples)} messages""" + ) + ) + result = [] + pair = [] + for n, item in enumerate(examples): + msg = _make_message(item) + pair.append(msg) + if n % 2 == 0: + continue + primer = glm.Example( + input=pair[0], + output=pair[1], + ) + result.append(primer) + pair = [] + return result + + +def _make_examples( + examples: discuss_types.ExamplesOptions, +) -> List[glm.Example]: + """ + Creates a list of `glm.Example` objects from the provided examples. + + This function takes various types of example content inputs and creates a list + of `glm.Example` objects. It handles the conversion of different input types and ensures + the appropriate structure for creating valid examples. + + Args: + examples: The examples to convert. + + Returns: + A list of `glm.Example` objects created from the provided examples. + """ + if isinstance(examples, glm.Example): + return [examples] + + if isinstance(examples, dict): + return [_make_example(examples)] + + examples = list(examples) + + if not examples: + return examples + + first = examples[0] + + if isinstance(first, dict): + if "content" in first: + # These are `Messages` + return _make_examples_from_flat(examples) + else: + if not ("input" in first and "output" in first): + raise TypeError( + "To create an `Example` from a dict you must supply both `input` and an `output` keys" + ) + else: + if isinstance(first, discuss_types.MESSAGE_OPTIONS): + return _make_examples_from_flat(examples) + + result = [] + for item in examples: + result.append(_make_example(item)) + return result + + +def _make_message_prompt_dict( + prompt: discuss_types.MessagePromptOptions = None, + *, + context: str | None = None, + examples: discuss_types.ExamplesOptions | None = None, + messages: discuss_types.MessagesOptions | None = None, +) -> glm.MessagePrompt: + """ + Creates a `glm.MessagePrompt` object from the provided prompt components. + + This function constructs a `glm.MessagePrompt` object using the provided `context`, `examples`, + or `messages`. It ensures the proper structure and handling of the input components. + + Either pass a `prompt` or it's component `context`, `examples`, `messages`. + + Args: + prompt: The complete prompt components. + context: The context for the prompt. + examples: The examples for the prompt. + messages: The messages for the prompt. + + Returns: + A `glm.MessagePrompt` object created from the provided prompt components. + """ + if prompt is None: + prompt = dict( + context=context, + examples=examples, + messages=messages, + ) + else: + flat_prompt = (context is not None) or (examples is not None) or (messages is not None) + if flat_prompt: + raise ValueError( + "You can't set `prompt`, and its fields `(context, examples, messages)`" + " at the same time" + ) + if isinstance(prompt, glm.MessagePrompt): + return prompt + elif isinstance(prompt, dict): # Always check dict before Iterable. + pass + else: + prompt = {"messages": prompt} + + keys = set(prompt.keys()) + if not keys.issubset(discuss_types.MESSAGE_PROMPT_KEYS): + raise KeyError( + f"Found extra entries in the prompt dictionary: {keys - discuss_types.MESSAGE_PROMPT_KEYS}" + ) + + examples = prompt.get("examples", None) + if examples is not None: + prompt["examples"] = _make_examples(examples) + messages = prompt.get("messages", None) + if messages is not None: + prompt["messages"] = _make_messages(messages) + + prompt = {k: v for k, v in prompt.items() if v is not None} + return prompt + + +def _make_message_prompt( + prompt: discuss_types.MessagePromptOptions = None, + *, + context: str | None = None, + examples: discuss_types.ExamplesOptions | None = None, + messages: discuss_types.MessagesOptions | None = None, +) -> glm.MessagePrompt: + """Creates a `glm.MessagePrompt` object from the provided prompt components.""" + prompt = _make_message_prompt_dict( + prompt=prompt, context=context, examples=examples, messages=messages + ) + return glm.MessagePrompt(prompt) + + +def _make_generate_message_request( + *, + model: model_types.AnyModelNameOptions | None, + context: str | None = None, + examples: discuss_types.ExamplesOptions | None = None, + messages: discuss_types.MessagesOptions | None = None, + temperature: float | None = None, + candidate_count: int | None = None, + top_p: float | None = None, + top_k: float | None = None, + prompt: discuss_types.MessagePromptOptions | None = None, +) -> glm.GenerateMessageRequest: + """Creates a `glm.GenerateMessageRequest` object for generating messages.""" + model = model_types.make_model_name(model) + + prompt = _make_message_prompt( + prompt=prompt, context=context, examples=examples, messages=messages + ) + + return glm.GenerateMessageRequest( + model=model, + prompt=prompt, + temperature=temperature, + top_p=top_p, + top_k=top_k, + candidate_count=candidate_count, + ) + + +DEFAULT_DISCUSS_MODEL = "models/chat-bison-001" + + +def chat( + *, + model: model_types.AnyModelNameOptions | None = "models/chat-bison-001", + context: str | None = None, + examples: discuss_types.ExamplesOptions | None = None, + messages: discuss_types.MessagesOptions | None = None, + temperature: float | None = None, + candidate_count: int | None = None, + top_p: float | None = None, + top_k: float | None = None, + prompt: discuss_types.MessagePromptOptions | None = None, + client: glm.DiscussServiceClient | None = None, + request_options: dict[str, Any] | None = None, +) -> discuss_types.ChatResponse: + """Calls the API and returns a `types.ChatResponse` containing the response. + + Args: + model: Which model to call, as a string or a `types.Model`. + context: Text that should be provided to the model first, to ground the response. + + If not empty, this `context` will be given to the model first before the + `examples` and `messages`. + + This field can be a description of your prompt to the model to help provide + context and guide the responses. + + Examples: + + * "Translate the phrase from English to French." + * "Given a statement, classify the sentiment as happy, sad or neutral." + + Anything included in this field will take precedence over history in `messages` + if the total input size exceeds the model's `Model.input_token_limit`. + examples: Examples of what the model should generate. + + This includes both the user input and the response that the model should + emulate. + + These `examples` are treated identically to conversation messages except + that they take precedence over the history in `messages`: + If the total input size exceeds the model's `input_token_limit` the input + will be truncated. Items will be dropped from `messages` before `examples` + messages: A snapshot of the conversation history sorted chronologically. + + Turns alternate between two authors. + + If the total input size exceeds the model's `input_token_limit` the input + will be truncated: The oldest items will be dropped from `messages`. + temperature: Controls the randomness of the output. Must be positive. + + Typical values are in the range: `[0.0,1.0]`. Higher values produce a + more random and varied response. A temperature of zero will be deterministic. + candidate_count: The **maximum** number of generated response messages to return. + + This value must be between `[1, 8]`, inclusive. If unset, this + will default to `1`. + + Note: Only unique candidates are returned. Higher temperatures are more + likely to produce unique candidates. Setting `temperature=0.0` will always + return 1 candidate regardless of the `candidate_count`. + top_k: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and + top-k sampling. + + `top_k` sets the maximum number of tokens to sample from on each step. + top_p: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and + top-k sampling. + + `top_p` configures the nucleus sampling. It sets the maximum cumulative + probability of tokens to sample from. + + For example, if the sorted probabilities are + `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample + as `[0.625, 0.25, 0.125, 0, 0, 0]`. + + Typical values are in the `[0.9, 1.0]` range. + prompt: You may pass a `types.MessagePromptOptions` **instead** of a + setting `context`/`examples`/`messages`, but not both. + client: If you're not relying on the default client, you pass a + `glm.DiscussServiceClient` instead. + request_options: Options for the request. + + Returns: + A `types.ChatResponse` containing the model's reply. + """ + request = _make_generate_message_request( + model=model, + context=context, + examples=examples, + messages=messages, + temperature=temperature, + candidate_count=candidate_count, + top_p=top_p, + top_k=top_k, + prompt=prompt, + ) + + return _generate_response(client=client, request=request, request_options=request_options) + + +@string_utils.set_doc(chat.__doc__) +async def chat_async( + *, + model: model_types.AnyModelNameOptions | None = "models/chat-bison-001", + context: str | None = None, + examples: discuss_types.ExamplesOptions | None = None, + messages: discuss_types.MessagesOptions | None = None, + temperature: float | None = None, + candidate_count: int | None = None, + top_p: float | None = None, + top_k: float | None = None, + prompt: discuss_types.MessagePromptOptions | None = None, + client: glm.DiscussServiceAsyncClient | None = None, + request_options: dict[str, Any] | None = None, +) -> discuss_types.ChatResponse: + request = _make_generate_message_request( + model=model, + context=context, + examples=examples, + messages=messages, + temperature=temperature, + candidate_count=candidate_count, + top_p=top_p, + top_k=top_k, + prompt=prompt, + ) + + return await _generate_response_async( + client=client, request=request, request_options=request_options + ) + + +if (sys.version_info.major, sys.version_info.minor) >= (3, 10): + DATACLASS_KWARGS = {"kw_only": True} +else: + DATACLASS_KWARGS = {} + + +@string_utils.prettyprint +@string_utils.set_doc(discuss_types.ChatResponse.__doc__) +@dataclasses.dataclass(**DATACLASS_KWARGS, init=False) +class ChatResponse(discuss_types.ChatResponse): + _client: glm.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False) + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + @property + @string_utils.set_doc(discuss_types.ChatResponse.last.__doc__) + def last(self) -> str | None: + if self.messages[-1]: + return self.messages[-1]["content"] + else: + return None + + @last.setter + def last(self, message: discuss_types.MessageOptions): + message = _make_message(message) + message = type(message).to_dict(message) + self.messages[-1] = message + + @string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__) + def reply( + self, + message: discuss_types.MessageOptions, + request_options: dict[str, Any] | None = None, + ) -> discuss_types.ChatResponse: + if isinstance(self._client, glm.DiscussServiceAsyncClient): + raise TypeError(f"reply can't be called on an async client, use reply_async instead.") + if self.last is None: + raise ValueError( + "The last response from the model did not return any candidates.\n" + "Check the `.filters` attribute to see why the responses were filtered:\n" + f"{self.filters}" + ) + + request = self.to_dict() + request.pop("candidates") + request.pop("filters", None) + request["messages"] = list(request["messages"]) + request["messages"].append(_make_message(message)) + request = _make_generate_message_request(**request) + return _generate_response( + request=request, client=self._client, request_options=request_options + ) + + @string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__) + async def reply_async( + self, message: discuss_types.MessageOptions + ) -> discuss_types.ChatResponse: + if isinstance(self._client, glm.DiscussServiceClient): + raise TypeError( + f"reply_async can't be called on a non-async client, use reply instead." + ) + request = self.to_dict() + request.pop("candidates") + request.pop("filters", None) + request["messages"] = list(request["messages"]) + request["messages"].append(_make_message(message)) + request = _make_generate_message_request(**request) + return await _generate_response_async(request=request, client=self._client) + + +def _build_chat_response( + request: glm.GenerateMessageRequest, + response: glm.GenerateMessageResponse, + client: glm.DiscussServiceClient | glm.DiscussServiceAsyncClient, +) -> ChatResponse: + request = type(request).to_dict(request) + prompt = request.pop("prompt") + request["examples"] = prompt["examples"] + request["context"] = prompt["context"] + request["messages"] = prompt["messages"] + + response = type(response).to_dict(response) + response.pop("messages") + + response["filters"] = safety_types.convert_filters_to_enums(response["filters"]) + + if response["candidates"]: + last = response["candidates"][0] + else: + last = None + request["messages"].append(last) + request.setdefault("temperature", None) + request.setdefault("candidate_count", None) + + return ChatResponse(_client=client, **response, **request) # pytype: disable=missing-parameter + + +def _generate_response( + request: glm.GenerateMessageRequest, + client: glm.DiscussServiceClient | None = None, + request_options: dict[str, Any] | None = None, +) -> ChatResponse: + if request_options is None: + request_options = {} + + if client is None: + client = get_default_discuss_client() + + response = client.generate_message(request, **request_options) + + return _build_chat_response(request, response, client) + + +async def _generate_response_async( + request: glm.GenerateMessageRequest, + client: glm.DiscussServiceAsyncClient | None = None, + request_options: dict[str, Any] | None = None, +) -> ChatResponse: + if request_options is None: + request_options = {} + + if client is None: + client = get_default_discuss_async_client() + + response = await client.generate_message(request, **request_options) + + return _build_chat_response(request, response, client) + + +def count_message_tokens( + *, + prompt: discuss_types.MessagePromptOptions = None, + context: str | None = None, + examples: discuss_types.ExamplesOptions | None = None, + messages: discuss_types.MessagesOptions | None = None, + model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL, + client: glm.DiscussServiceAsyncClient | None = None, + request_options: dict[str, Any] | None = None, +) -> discuss_types.TokenCount: + model = model_types.make_model_name(model) + prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages) + + if request_options is None: + request_options = {} + + if client is None: + client = get_default_discuss_client() + + result = client.count_message_tokens(model=model, prompt=prompt, **request_options) + + return type(result).to_dict(result) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/embedding.py b/.venv/lib/python3.11/site-packages/google/generativeai/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..15645c79280987e213e159dffcbfc0cc63d38823 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/embedding.py @@ -0,0 +1,312 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import itertools +from typing import Any, Iterable, overload, TypeVar, Union, Mapping + +import google.ai.generativelanguage as glm +from google.generativeai import protos + +from google.generativeai.client import get_default_generative_client +from google.generativeai.client import get_default_generative_async_client + +from google.generativeai.types import helper_types +from google.generativeai.types import model_types +from google.generativeai.types import text_types +from google.generativeai.types import content_types + +DEFAULT_EMB_MODEL = "models/embedding-001" +EMBEDDING_MAX_BATCH_SIZE = 100 + +EmbeddingTaskType = protos.TaskType + +EmbeddingTaskTypeOptions = Union[int, str, EmbeddingTaskType] + +_EMBEDDING_TASK_TYPE: dict[EmbeddingTaskTypeOptions, EmbeddingTaskType] = { + EmbeddingTaskType.TASK_TYPE_UNSPECIFIED: EmbeddingTaskType.TASK_TYPE_UNSPECIFIED, + 0: EmbeddingTaskType.TASK_TYPE_UNSPECIFIED, + "task_type_unspecified": EmbeddingTaskType.TASK_TYPE_UNSPECIFIED, + "unspecified": EmbeddingTaskType.TASK_TYPE_UNSPECIFIED, + EmbeddingTaskType.RETRIEVAL_QUERY: EmbeddingTaskType.RETRIEVAL_QUERY, + 1: EmbeddingTaskType.RETRIEVAL_QUERY, + "retrieval_query": EmbeddingTaskType.RETRIEVAL_QUERY, + "query": EmbeddingTaskType.RETRIEVAL_QUERY, + EmbeddingTaskType.RETRIEVAL_DOCUMENT: EmbeddingTaskType.RETRIEVAL_DOCUMENT, + 2: EmbeddingTaskType.RETRIEVAL_DOCUMENT, + "retrieval_document": EmbeddingTaskType.RETRIEVAL_DOCUMENT, + "document": EmbeddingTaskType.RETRIEVAL_DOCUMENT, + EmbeddingTaskType.SEMANTIC_SIMILARITY: EmbeddingTaskType.SEMANTIC_SIMILARITY, + 3: EmbeddingTaskType.SEMANTIC_SIMILARITY, + "semantic_similarity": EmbeddingTaskType.SEMANTIC_SIMILARITY, + "similarity": EmbeddingTaskType.SEMANTIC_SIMILARITY, + EmbeddingTaskType.CLASSIFICATION: EmbeddingTaskType.CLASSIFICATION, + 4: EmbeddingTaskType.CLASSIFICATION, + "classification": EmbeddingTaskType.CLASSIFICATION, + EmbeddingTaskType.CLUSTERING: EmbeddingTaskType.CLUSTERING, + 5: EmbeddingTaskType.CLUSTERING, + "clustering": EmbeddingTaskType.CLUSTERING, + 6: EmbeddingTaskType.QUESTION_ANSWERING, + "question_answering": EmbeddingTaskType.QUESTION_ANSWERING, + "qa": EmbeddingTaskType.QUESTION_ANSWERING, + EmbeddingTaskType.QUESTION_ANSWERING: EmbeddingTaskType.QUESTION_ANSWERING, + 7: EmbeddingTaskType.FACT_VERIFICATION, + "fact_verification": EmbeddingTaskType.FACT_VERIFICATION, + "verification": EmbeddingTaskType.FACT_VERIFICATION, + EmbeddingTaskType.FACT_VERIFICATION: EmbeddingTaskType.FACT_VERIFICATION, +} + + +def to_task_type(x: EmbeddingTaskTypeOptions) -> EmbeddingTaskType: + if isinstance(x, str): + x = x.lower() + return _EMBEDDING_TASK_TYPE[x] + + +try: + # python 3.12+ + _batched = itertools.batched # type: ignore +except AttributeError: + T = TypeVar("T") + + def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]: + if n < 1: + raise ValueError( + f"Invalid input: The batch size 'n' must be a positive integer. You entered: {n}. Please enter a number greater than 0." + ) + batch = [] + for item in iterable: + batch.append(item) + if len(batch) == n: + yield batch + batch = [] + + if batch: + yield batch + + +@overload +def embed_content( + model: model_types.BaseModelNameOptions, + content: content_types.ContentType, + task_type: EmbeddingTaskTypeOptions | None = None, + title: str | None = None, + output_dimensionality: int | None = None, + client: glm.GenerativeServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> text_types.EmbeddingDict: ... + + +@overload +def embed_content( + model: model_types.BaseModelNameOptions, + content: Iterable[content_types.ContentType], + task_type: EmbeddingTaskTypeOptions | None = None, + title: str | None = None, + output_dimensionality: int | None = None, + client: glm.GenerativeServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> text_types.BatchEmbeddingDict: ... + + +def embed_content( + model: model_types.BaseModelNameOptions, + content: content_types.ContentType | Iterable[content_types.ContentType], + task_type: EmbeddingTaskTypeOptions | None = None, + title: str | None = None, + output_dimensionality: int | None = None, + client: glm.GenerativeServiceClient = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: + """Calls the API to create embeddings for content passed in. + + Args: + model: + Which [model](https://ai.google.dev/models/gemini#embedding) to + call, as a string or a `types.Model`. + + content: + Content to embed. + + task_type: + Optional task type for which the embeddings will be used. Can only + be set for `models/embedding-001`. + + title: + An optional title for the text. Only applicable when task_type is + `RETRIEVAL_DOCUMENT`. + + output_dimensionality: + Optional reduced dimensionality for the output embeddings. If set, + excessive values from the output embeddings will be truncated from + the end. + + request_options: + Options for the request. + + Return: + Dictionary containing the embedding (list of float values) for the + input content. + """ + model = model_types.make_model_name(model) + + if request_options is None: + request_options = {} + + if client is None: + client = get_default_generative_client() + + if title and to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT: + raise ValueError( + f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}." + ) + + if output_dimensionality and output_dimensionality < 0: + raise ValueError( + f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}." + ) + + if task_type: + task_type = to_task_type(task_type) + + if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)): + result = {"embedding": []} + requests = ( + protos.EmbedContentRequest( + model=model, + content=content_types.to_content(c), + task_type=task_type, + title=title, + output_dimensionality=output_dimensionality, + ) + for c in content + ) + for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE): + embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch) + embedding_response = client.batch_embed_contents( + embedding_request, + **request_options, + ) + embedding_dict = type(embedding_response).to_dict(embedding_response) + result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"]) + return result + else: + embedding_request = protos.EmbedContentRequest( + model=model, + content=content_types.to_content(content), + task_type=task_type, + title=title, + output_dimensionality=output_dimensionality, + ) + embedding_response = client.embed_content( + embedding_request, + **request_options, + ) + embedding_dict = type(embedding_response).to_dict(embedding_response) + embedding_dict["embedding"] = embedding_dict["embedding"]["values"] + return embedding_dict + + +@overload +async def embed_content_async( + model: model_types.BaseModelNameOptions, + content: content_types.ContentType, + task_type: EmbeddingTaskTypeOptions | None = None, + title: str | None = None, + output_dimensionality: int | None = None, + client: glm.GenerativeServiceAsyncClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> text_types.EmbeddingDict: ... + + +@overload +async def embed_content_async( + model: model_types.BaseModelNameOptions, + content: Iterable[content_types.ContentType], + task_type: EmbeddingTaskTypeOptions | None = None, + title: str | None = None, + output_dimensionality: int | None = None, + client: glm.GenerativeServiceAsyncClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> text_types.BatchEmbeddingDict: ... + + +async def embed_content_async( + model: model_types.BaseModelNameOptions, + content: content_types.ContentType | Iterable[content_types.ContentType], + task_type: EmbeddingTaskTypeOptions | None = None, + title: str | None = None, + output_dimensionality: int | None = None, + client: glm.GenerativeServiceAsyncClient = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: + """Calls the API to create async embeddings for content passed in.""" + + model = model_types.make_model_name(model) + + if request_options is None: + request_options = {} + + if client is None: + client = get_default_generative_async_client() + + if title and to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT: + raise ValueError( + f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}." + ) + if output_dimensionality and output_dimensionality < 0: + raise ValueError( + f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}." + ) + + if task_type: + task_type = to_task_type(task_type) + + if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)): + result = {"embedding": []} + requests = ( + protos.EmbedContentRequest( + model=model, + content=content_types.to_content(c), + task_type=task_type, + title=title, + output_dimensionality=output_dimensionality, + ) + for c in content + ) + for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE): + embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch) + embedding_response = await client.batch_embed_contents( + embedding_request, + **request_options, + ) + embedding_dict = type(embedding_response).to_dict(embedding_response) + result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"]) + return result + else: + embedding_request = protos.EmbedContentRequest( + model=model, + content=content_types.to_content(content), + task_type=task_type, + title=title, + output_dimensionality=output_dimensionality, + ) + embedding_response = await client.embed_content( + embedding_request, + **request_options, + ) + embedding_dict = type(embedding_response).to_dict(embedding_response) + embedding_dict["embedding"] = embedding_dict["embedding"]["values"] + return embedding_dict diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/files.py b/.venv/lib/python3.11/site-packages/google/generativeai/files.py new file mode 100644 index 0000000000000000000000000000000000000000..b2581bdcd38b708238ecd7e38f16b5b309c412df --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/files.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import pathlib +import mimetypes +from typing import Iterable +import logging +from google.generativeai import protos +from itertools import islice +from io import IOBase + +from google.generativeai.types import file_types + +from google.generativeai.client import get_default_file_client + +__all__ = ["upload_file", "get_file", "list_files", "delete_file"] + +mimetypes.add_type("image/webp", ".webp") + + +def upload_file( + path: str | pathlib.Path | os.PathLike | IOBase, + *, + mime_type: str | None = None, + name: str | None = None, + display_name: str | None = None, + resumable: bool = True, +) -> file_types.File: + """Calls the API to upload a file using a supported file service. + + Args: + path: The path to the file or a file-like object (e.g., BytesIO) to be uploaded. + mime_type: The MIME type of the file. If not provided, it will be + inferred from the file extension. + name: The name of the file in the destination (e.g., 'files/sample-image'). + If not provided, a system generated ID will be created. + display_name: Optional display name of the file. + resumable: Whether to use the resumable upload protocol. By default, this is enabled. + See details at + https://googleapis.github.io/google-api-python-client/docs/epy/googleapiclient.http.MediaFileUpload-class.html#resumable + + Returns: + file_types.File: The response of the uploaded file. + """ + client = get_default_file_client() + + if isinstance(path, IOBase): + if mime_type is None: + raise ValueError( + "Unknown mime type: When passing a file like object to `path` (instead of a\n" + " path-like object) you must set the `mime_type` argument" + ) + else: + path = pathlib.Path(os.fspath(path)) + + if display_name is None: + display_name = path.name + + if mime_type is None: + mime_type, _ = mimetypes.guess_type(path) + + if mime_type is None: + raise ValueError( + "Unknown mime type: Could not determine the mimetype for your file\n" + " please set the `mime_type` argument" + ) + + if name is not None and "/" not in name: + name = f"files/{name}" + + response = client.create_file( + path=path, mime_type=mime_type, name=name, display_name=display_name, resumable=resumable + ) + return file_types.File(response) + + +def list_files(page_size=100) -> Iterable[file_types.File]: + """Calls the API to list files using a supported file service.""" + client = get_default_file_client() + + response = client.list_files(protos.ListFilesRequest(page_size=page_size)) + for proto in response: + yield file_types.File(proto) + + +def get_file(name: str) -> file_types.File: + """Calls the API to retrieve a specified file using a supported file service.""" + if "/" not in name: + name = f"files/{name}" + client = get_default_file_client() + return file_types.File(client.get_file(name=name)) + + +def delete_file(name: str | file_types.File | protos.File): + """Calls the API to permanently delete a specified file using a supported file service.""" + if isinstance(name, (file_types.File, protos.File)): + name = name.name + elif "/" not in name: + name = f"files/{name}" + request = protos.DeleteFileRequest(name=name) + client = get_default_file_client() + client.delete_file(request=request) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/generative_models.py b/.venv/lib/python3.11/site-packages/google/generativeai/generative_models.py new file mode 100644 index 0000000000000000000000000000000000000000..8d331a9f619b25a279af5cc62fea652488e1ad84 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/generative_models.py @@ -0,0 +1,875 @@ +"""Classes for working with the Gemini models.""" + +from __future__ import annotations + +from collections.abc import Iterable +import textwrap +from typing import Any, Union, overload +import reprlib + +# pylint: disable=bad-continuation, line-too-long + + +import google.api_core.exceptions +from google.generativeai import protos +from google.generativeai import client + +from google.generativeai import caching +from google.generativeai.types import content_types +from google.generativeai.types import generation_types +from google.generativeai.types import helper_types +from google.generativeai.types import safety_types + +_USER_ROLE = "user" +_MODEL_ROLE = "model" + + +class GenerativeModel: + """ + The `genai.GenerativeModel` class wraps default parameters for calls to + `GenerativeModel.generate_content`, `GenerativeModel.count_tokens`, and + `GenerativeModel.start_chat`. + + This family of functionality is designed to support multi-turn conversations, and multimodal + requests. What media-types are supported for input and output is model-dependant. + + >>> import google.generativeai as genai + >>> import PIL.Image + >>> genai.configure(api_key='YOUR_API_KEY') + >>> model = genai.GenerativeModel('models/gemini-1.5-flash') + >>> result = model.generate_content('Tell me a story about a magic backpack') + >>> result.text + "In the quaint little town of Lakeside, there lived a young girl named Lily..." + + Multimodal input: + + >>> model = genai.GenerativeModel('models/gemini-1.5-flash') + >>> result = model.generate_content([ + ... "Give me a recipe for these:", PIL.Image.open('scones.jpeg')]) + >>> result.text + "**Blueberry Scones** ..." + + Multi-turn conversation: + + >>> chat = model.start_chat() + >>> response = chat.send_message("Hi, I have some questions for you.") + >>> response.text + "Sure, I'll do my best to answer your questions..." + + To list the compatible model names use: + + >>> for m in genai.list_models(): + ... if 'generateContent' in m.supported_generation_methods: + ... print(m.name) + + Arguments: + model_name: The name of the model to query. To list compatible models use + safety_settings: Sets the default safety filters. This controls which content is blocked + by the api before being returned. + generation_config: A `genai.GenerationConfig` setting the default generation parameters to + use. + """ + + def __init__( + self, + model_name: str = "gemini-1.5-flash-002", + safety_settings: safety_types.SafetySettingOptions | None = None, + generation_config: generation_types.GenerationConfigType | None = None, + tools: content_types.FunctionLibraryType | None = None, + tool_config: content_types.ToolConfigType | None = None, + system_instruction: content_types.ContentType | None = None, + ): + if "/" not in model_name: + model_name = "models/" + model_name + self._model_name = model_name + self._safety_settings = safety_types.to_easy_safety_dict(safety_settings) + self._generation_config = generation_types.to_generation_config_dict(generation_config) + self._tools = content_types.to_function_library(tools) + + if tool_config is None: + self._tool_config = None + else: + self._tool_config = content_types.to_tool_config(tool_config) + + if system_instruction is None: + self._system_instruction = None + else: + self._system_instruction = content_types.to_content(system_instruction) + + self._client = None + self._async_client = None + + @property + def cached_content(self) -> str: + return getattr(self, "_cached_content", None) + + @property + def model_name(self): + return self._model_name + + def __str__(self): + def maybe_text(content): + if content and len(content.parts) and (t := content.parts[0].text): + return repr(t) + return content + + return textwrap.dedent( + f"""\ + genai.GenerativeModel( + model_name='{self.model_name}', + generation_config={self._generation_config}, + safety_settings={self._safety_settings}, + tools={self._tools}, + system_instruction={maybe_text(self._system_instruction)}, + cached_content={self.cached_content} + )""" + ) + + __repr__ = __str__ + + def _prepare_request( + self, + *, + contents: content_types.ContentsType, + generation_config: generation_types.GenerationConfigType | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + tools: content_types.FunctionLibraryType | None, + tool_config: content_types.ToolConfigType | None, + ) -> protos.GenerateContentRequest: + """Creates a `protos.GenerateContentRequest` from raw inputs.""" + if hasattr(self, "_cached_content") and any([self._system_instruction, tools, tool_config]): + raise ValueError( + "`tools`, `tool_config`, `system_instruction` cannot be set on a model instantiated with `cached_content` as its context." + ) + + tools_lib = self._get_tools_lib(tools) + if tools_lib is not None: + tools_lib = tools_lib.to_proto() + + if tool_config is None: + tool_config = self._tool_config + else: + tool_config = content_types.to_tool_config(tool_config) + + contents = content_types.to_contents(contents) + + generation_config = generation_types.to_generation_config_dict(generation_config) + merged_gc = self._generation_config.copy() + merged_gc.update(generation_config) + + safety_settings = safety_types.to_easy_safety_dict(safety_settings) + merged_ss = self._safety_settings.copy() + merged_ss.update(safety_settings) + merged_ss = safety_types.normalize_safety_settings(merged_ss) + + return protos.GenerateContentRequest( + model=self._model_name, + contents=contents, + generation_config=merged_gc, + safety_settings=merged_ss, + tools=tools_lib, + tool_config=tool_config, + system_instruction=self._system_instruction, + cached_content=self.cached_content, + ) + + def _get_tools_lib( + self, tools: content_types.FunctionLibraryType + ) -> content_types.FunctionLibrary | None: + if tools is None: + return self._tools + else: + return content_types.to_function_library(tools) + + @overload + @classmethod + def from_cached_content( + cls, + cached_content: str, + *, + generation_config: generation_types.GenerationConfigType | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + ) -> GenerativeModel: ... + + @overload + @classmethod + def from_cached_content( + cls, + cached_content: caching.CachedContent, + *, + generation_config: generation_types.GenerationConfigType | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + ) -> GenerativeModel: ... + + @classmethod + def from_cached_content( + cls, + cached_content: str | caching.CachedContent, + *, + generation_config: generation_types.GenerationConfigType | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + ) -> GenerativeModel: + """Creates a model with `cached_content` as model's context. + + Args: + cached_content: context for the model. + generation_config: Overrides for the model's generation config. + safety_settings: Overrides for the model's safety settings. + + Returns: + `GenerativeModel` object with `cached_content` as its context. + """ + if isinstance(cached_content, str): + cached_content = caching.CachedContent.get(name=cached_content) + + # call __init__ to set the model's `generation_config`, `safety_settings`. + # `model_name` will be the name of the model for which the `cached_content` was created. + self = cls( + model_name=cached_content.model, + generation_config=generation_config, + safety_settings=safety_settings, + ) + + # set the model's context. + setattr(self, "_cached_content", cached_content.name) + return self + + def generate_content( + self, + contents: content_types.ContentsType, + *, + generation_config: generation_types.GenerationConfigType | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + stream: bool = False, + tools: content_types.FunctionLibraryType | None = None, + tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, + ) -> generation_types.GenerateContentResponse: + """A multipurpose function to generate responses from the model. + + This `GenerativeModel.generate_content` method can handle multimodal input, and multi-turn + conversations. + + >>> model = genai.GenerativeModel('models/gemini-1.5-flash') + >>> response = model.generate_content('Tell me a story about a magic backpack') + >>> response.text + + ### Streaming + + This method supports streaming with the `stream=True`. The result has the same type as the non streaming case, + but you can iterate over the response chunks as they become available: + + >>> response = model.generate_content('Tell me a story about a magic backpack', stream=True) + >>> for chunk in response: + ... print(chunk.text) + + ### Multi-turn + + This method supports multi-turn chats but is **stateless**: the entire conversation history needs to be sent with each + request. This takes some manual management but gives you complete control: + + >>> messages = [{'role':'user', 'parts': ['hello']}] + >>> response = model.generate_content(messages) # "Hello, how can I help" + >>> messages.append(response.candidates[0].content) + >>> messages.append({'role':'user', 'parts': ['How does quantum physics work?']}) + >>> response = model.generate_content(messages) + + For a simpler multi-turn interface see `GenerativeModel.start_chat`. + + ### Input type flexibility + + While the underlying API strictly expects a `list[protos.Content]` objects, this method + will convert the user input into the correct type. The hierarchy of types that can be + converted is below. Any of these objects can be passed as an equivalent `dict`. + + * `Iterable[protos.Content]` + * `protos.Content` + * `Iterable[protos.Part]` + * `protos.Part` + * `str`, `Image`, or `protos.Blob` + + In an `Iterable[protos.Content]` each `content` is a separate message. + But note that an `Iterable[protos.Part]` is taken as the parts of a single message. + + Arguments: + contents: The contents serving as the model's prompt. + generation_config: Overrides for the model's generation config. + safety_settings: Overrides for the model's safety settings. + stream: If True, yield response chunks as they are generated. + tools: `protos.Tools` more info coming soon. + request_options: Options for the request. + """ + if not contents: + raise TypeError("contents must not be empty") + + request = self._prepare_request( + contents=contents, + generation_config=generation_config, + safety_settings=safety_settings, + tools=tools, + tool_config=tool_config, + ) + + if request.contents and not request.contents[-1].role: + request.contents[-1].role = _USER_ROLE + + if self._client is None: + self._client = client.get_default_generative_client() + + if request_options is None: + request_options = {} + + try: + if stream: + with generation_types.rewrite_stream_error(): + iterator = self._client.stream_generate_content( + request, + **request_options, + ) + return generation_types.GenerateContentResponse.from_iterator(iterator) + else: + response = self._client.generate_content( + request, + **request_options, + ) + return generation_types.GenerateContentResponse.from_response(response) + except google.api_core.exceptions.InvalidArgument as e: + if e.message.startswith("Request payload size exceeds the limit:"): + e.message += ( + " The file size is too large. Please use the File API to upload your files instead. " + "Example: `f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`" + ) + raise + + async def generate_content_async( + self, + contents: content_types.ContentsType, + *, + generation_config: generation_types.GenerationConfigType | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + stream: bool = False, + tools: content_types.FunctionLibraryType | None = None, + tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, + ) -> generation_types.AsyncGenerateContentResponse: + """The async version of `GenerativeModel.generate_content`.""" + if not contents: + raise TypeError("contents must not be empty") + + request = self._prepare_request( + contents=contents, + generation_config=generation_config, + safety_settings=safety_settings, + tools=tools, + tool_config=tool_config, + ) + + if request.contents and not request.contents[-1].role: + request.contents[-1].role = _USER_ROLE + + if self._async_client is None: + self._async_client = client.get_default_generative_async_client() + + if request_options is None: + request_options = {} + + try: + if stream: + with generation_types.rewrite_stream_error(): + iterator = await self._async_client.stream_generate_content( + request, + **request_options, + ) + return await generation_types.AsyncGenerateContentResponse.from_aiterator(iterator) + else: + response = await self._async_client.generate_content( + request, + **request_options, + ) + return generation_types.AsyncGenerateContentResponse.from_response(response) + except google.api_core.exceptions.InvalidArgument as e: + if e.message.startswith("Request payload size exceeds the limit:"): + e.message += ( + " The file size is too large. Please use the File API to upload your files instead. " + "Example: `f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`" + ) + raise + + # fmt: off + def count_tokens( + self, + contents: content_types.ContentsType = None, + *, + generation_config: generation_types.GenerationConfigType | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + tools: content_types.FunctionLibraryType | None = None, + tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, + ) -> protos.CountTokensResponse: + if request_options is None: + request_options = {} + + if self._client is None: + self._client = client.get_default_generative_client() + + request = protos.CountTokensRequest( + model=self.model_name, + generate_content_request=self._prepare_request( + contents=contents, + generation_config=generation_config, + safety_settings=safety_settings, + tools=tools, + tool_config=tool_config, + )) + return self._client.count_tokens(request, **request_options) + + async def count_tokens_async( + self, + contents: content_types.ContentsType = None, + *, + generation_config: generation_types.GenerationConfigType | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + tools: content_types.FunctionLibraryType | None = None, + tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, + ) -> protos.CountTokensResponse: + if request_options is None: + request_options = {} + + if self._async_client is None: + self._async_client = client.get_default_generative_async_client() + + request = protos.CountTokensRequest( + model=self.model_name, + generate_content_request=self._prepare_request( + contents=contents, + generation_config=generation_config, + safety_settings=safety_settings, + tools=tools, + tool_config=tool_config, + )) + return await self._async_client.count_tokens(request, **request_options) + + # fmt: on + + def start_chat( + self, + *, + history: Iterable[content_types.StrictContentType] | None = None, + enable_automatic_function_calling: bool = False, + ) -> ChatSession: + """Returns a `genai.ChatSession` attached to this model. + + >>> model = genai.GenerativeModel() + >>> chat = model.start_chat(history=[...]) + >>> response = chat.send_message("Hello?") + + Arguments: + history: An iterable of `protos.Content` objects, or equivalents to initialize the session. + """ + if self._generation_config.get("candidate_count", 1) > 1: + raise ValueError( + "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1." + ) + return ChatSession( + model=self, + history=history, + enable_automatic_function_calling=enable_automatic_function_calling, + ) + + +class ChatSession: + """Contains an ongoing conversation with the model. + + >>> model = genai.GenerativeModel('models/gemini-1.5-flash') + >>> chat = model.start_chat() + >>> response = chat.send_message("Hello") + >>> print(response.text) + >>> response = chat.send_message("Hello again") + >>> print(response.text) + >>> response = chat.send_message(... + + This `ChatSession` object collects the messages sent and received, in its + `ChatSession.history` attribute. + + Arguments: + model: The model to use in the chat. + history: A chat history to initialize the object with. + """ + + def __init__( + self, + model: GenerativeModel, + history: Iterable[content_types.StrictContentType] | None = None, + enable_automatic_function_calling: bool = False, + ): + self.model: GenerativeModel = model + self._history: list[protos.Content] = content_types.to_contents(history) + self._last_sent: protos.Content | None = None + self._last_received: generation_types.BaseGenerateContentResponse | None = None + self.enable_automatic_function_calling = enable_automatic_function_calling + + def send_message( + self, + content: content_types.ContentType, + *, + generation_config: generation_types.GenerationConfigType = None, + safety_settings: safety_types.SafetySettingOptions = None, + stream: bool = False, + tools: content_types.FunctionLibraryType | None = None, + tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, + ) -> generation_types.GenerateContentResponse: + """Sends the conversation history with the added message and returns the model's response. + + Appends the request and response to the conversation history. + + >>> model = genai.GenerativeModel('models/gemini-1.5-flash') + >>> chat = model.start_chat() + >>> response = chat.send_message("Hello") + >>> print(response.text) + "Hello! How can I assist you today?" + >>> len(chat.history) + 2 + + Call it with `stream=True` to receive response chunks as they are generated: + + >>> chat = model.start_chat() + >>> response = chat.send_message("Explain quantum physics", stream=True) + >>> for chunk in response: + ... print(chunk.text, end='') + + Once iteration over chunks is complete, the `response` and `ChatSession` are in states identical to the + `stream=False` case. Some properties are not available until iteration is complete. + + Like `GenerativeModel.generate_content` this method lets you override the model's `generation_config` and + `safety_settings`. + + Arguments: + content: The message contents. + generation_config: Overrides for the model's generation config. + safety_settings: Overrides for the model's safety settings. + stream: If True, yield response chunks as they are generated. + """ + if request_options is None: + request_options = {} + + if self.enable_automatic_function_calling and stream: + raise NotImplementedError( + "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`." + ) + + tools_lib = self.model._get_tools_lib(tools) + + content = content_types.to_content(content) + + if not content.role: + content.role = _USER_ROLE + + history = self.history[:] + history.append(content) + + generation_config = generation_types.to_generation_config_dict(generation_config) + if generation_config.get("candidate_count", 1) > 1: + raise ValueError( + "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1." + ) + + response = self.model.generate_content( + contents=history, + generation_config=generation_config, + safety_settings=safety_settings, + stream=stream, + tools=tools_lib, + tool_config=tool_config, + request_options=request_options, + ) + + self._check_response(response=response, stream=stream) + + if self.enable_automatic_function_calling and tools_lib is not None: + self.history, content, response = self._handle_afc( + response=response, + history=history, + generation_config=generation_config, + safety_settings=safety_settings, + stream=stream, + tools_lib=tools_lib, + request_options=request_options, + ) + + self._last_sent = content + self._last_received = response + + return response + + def _check_response(self, *, response, stream): + if response.prompt_feedback.block_reason: + raise generation_types.BlockedPromptException(response.prompt_feedback) + + if not stream: + if response.candidates[0].finish_reason not in ( + protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, + protos.Candidate.FinishReason.STOP, + protos.Candidate.FinishReason.MAX_TOKENS, + ): + raise generation_types.StopCandidateException(response.candidates[0]) + + def _get_function_calls(self, response) -> list[protos.FunctionCall]: + candidates = response.candidates + if len(candidates) != 1: + raise ValueError( + f"Invalid number of candidates: Automatic function calling only works with 1 candidate, but {len(candidates)} were provided." + ) + parts = candidates[0].content.parts + function_calls = [part.function_call for part in parts if part and "function_call" in part] + return function_calls + + def _handle_afc( + self, + *, + response, + history, + generation_config, + safety_settings, + stream, + tools_lib, + request_options, + ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]: + + while function_calls := self._get_function_calls(response): + if not all(callable(tools_lib[fc]) for fc in function_calls): + break + history.append(response.candidates[0].content) + + function_response_parts: list[protos.Part] = [] + for fc in function_calls: + fr = tools_lib(fc) + assert fr is not None, ( + "Unexpected state: The function reference (fr) should never be None. It should only return None if the declaration " + "is not callable, which is checked earlier in the code." + ) + function_response_parts.append(fr) + + send = protos.Content(role=_USER_ROLE, parts=function_response_parts) + history.append(send) + + response = self.model.generate_content( + contents=history, + generation_config=generation_config, + safety_settings=safety_settings, + stream=stream, + tools=tools_lib, + request_options=request_options, + ) + + self._check_response(response=response, stream=stream) + + *history, content = history + return history, content, response + + async def send_message_async( + self, + content: content_types.ContentType, + *, + generation_config: generation_types.GenerationConfigType = None, + safety_settings: safety_types.SafetySettingOptions = None, + stream: bool = False, + tools: content_types.FunctionLibraryType | None = None, + tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, + ) -> generation_types.AsyncGenerateContentResponse: + """The async version of `ChatSession.send_message`.""" + if request_options is None: + request_options = {} + + if self.enable_automatic_function_calling and stream: + raise NotImplementedError( + "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`." + ) + + tools_lib = self.model._get_tools_lib(tools) + + content = content_types.to_content(content) + + if not content.role: + content.role = _USER_ROLE + + history = self.history[:] + history.append(content) + + generation_config = generation_types.to_generation_config_dict(generation_config) + if generation_config.get("candidate_count", 1) > 1: + raise ValueError( + "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1." + ) + + response = await self.model.generate_content_async( + contents=history, + generation_config=generation_config, + safety_settings=safety_settings, + stream=stream, + tools=tools_lib, + tool_config=tool_config, + request_options=request_options, + ) + + self._check_response(response=response, stream=stream) + + if self.enable_automatic_function_calling and tools_lib is not None: + self.history, content, response = await self._handle_afc_async( + response=response, + history=history, + generation_config=generation_config, + safety_settings=safety_settings, + stream=stream, + tools_lib=tools_lib, + request_options=request_options, + ) + + self._last_sent = content + self._last_received = response + + return response + + async def _handle_afc_async( + self, + *, + response, + history, + generation_config, + safety_settings, + stream, + tools_lib, + request_options, + ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]: + + while function_calls := self._get_function_calls(response): + if not all(callable(tools_lib[fc]) for fc in function_calls): + break + history.append(response.candidates[0].content) + + function_response_parts: list[protos.Part] = [] + for fc in function_calls: + fr = tools_lib(fc) + assert fr is not None, ( + "Unexpected state: The function reference (fr) should never be None. It should only return None if the declaration " + "is not callable, which is checked earlier in the code." + ) + function_response_parts.append(fr) + + send = protos.Content(role=_USER_ROLE, parts=function_response_parts) + history.append(send) + + response = await self.model.generate_content_async( + contents=history, + generation_config=generation_config, + safety_settings=safety_settings, + stream=stream, + tools=tools_lib, + request_options=request_options, + ) + + self._check_response(response=response, stream=stream) + + *history, content = history + return history, content, response + + def __copy__(self): + return ChatSession( + model=self.model, + # Be sure the copy doesn't share the history. + history=list(self.history), + ) + + def rewind(self) -> tuple[protos.Content, protos.Content]: + """Removes the last request/response pair from the chat history.""" + if self._last_received is None: + result = self._history.pop(-2), self._history.pop() + return result + else: + result = self._last_sent, self._last_received.candidates[0].content + self._last_sent = None + self._last_received = None + return result + + @property + def last(self) -> generation_types.BaseGenerateContentResponse | None: + """returns the last received `genai.GenerateContentResponse`""" + return self._last_received + + @property + def history(self) -> list[protos.Content]: + """The chat history.""" + last = self._last_received + if last is None: + return self._history + + if last.candidates[0].finish_reason not in ( + protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, + protos.Candidate.FinishReason.STOP, + protos.Candidate.FinishReason.MAX_TOKENS, + ): + error = generation_types.StopCandidateException(last.candidates[0]) + last._error = error + + if last._error is not None: + raise generation_types.BrokenResponseError( + "Unable to build a coherent chat history due to a broken streaming response. " + "Refer to the previous exception for details. " + "To inspect the last response object, use `chat.last`. " + "To remove the last request/response `Content` objects from the chat, " + "call `last_send, last_received = chat.rewind()` and continue without it." + ) from last._error + + sent = self._last_sent + received = last.candidates[0].content + if not received.role: + received.role = _MODEL_ROLE + self._history.extend([sent, received]) + + self._last_sent = None + self._last_received = None + + return self._history + + @history.setter + def history(self, history): + self._history = content_types.to_contents(history) + self._last_sent = None + self._last_received = None + + def __repr__(self) -> str: + _dict_repr = reprlib.Repr() + _model = str(self.model).replace("\n", "\n" + " " * 4) + + def content_repr(x): + return f"protos.Content({_dict_repr.repr(type(x).to_dict(x))})" + + try: + history = list(self.history) + except (generation_types.BrokenResponseError, generation_types.IncompleteIterationError): + history = list(self._history) + + if self._last_sent is not None: + history.append(self._last_sent) + history = [content_repr(x) for x in history] + + last_received = self._last_received + if last_received is not None: + if last_received._error is not None: + history.append("") + else: + history.append("") + + _history = ",\n " + f"history=[{', '.join(history)}]\n)" + + return ( + textwrap.dedent( + f"""\ + ChatSession( + model=""" + ) + + _model + + _history + ) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/models.py b/.venv/lib/python3.11/site-packages/google/generativeai/models.py new file mode 100644 index 0000000000000000000000000000000000000000..b23a7ce8805268a3864a6401cad31a6c95e1252c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/models.py @@ -0,0 +1,467 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import typing +from typing import Any, Literal + +import google.ai.generativelanguage as glm + +from google.generativeai import protos +from google.generativeai import operations +from google.generativeai.client import get_default_model_client +from google.generativeai.types import model_types +from google.generativeai.types import helper_types +from google.api_core import operation +from google.api_core import protobuf_helpers +from google.protobuf import field_mask_pb2 +from google.generativeai.utils import flatten_update_paths + + +def get_model( + name: model_types.AnyModelNameOptions, + *, + client=None, + request_options: helper_types.RequestOptionsType | None = None, +) -> model_types.Model | model_types.TunedModel: + """Calls the API to fetch a model by name. + + ``` + import pprint + model = genai.get_model('models/gemini-1.5-flash') + pprint.pprint(model) + ``` + + Args: + name: The name of the model to fetch. Should start with `models/` + client: The client to use. + request_options: Options for the request. + + Returns: + A `types.Model` + """ + name = model_types.make_model_name(name) + if name.startswith("models/"): + return get_base_model(name, client=client, request_options=request_options) + elif name.startswith("tunedModels/"): + return get_tuned_model(name, client=client, request_options=request_options) + else: + raise ValueError( + f"Invalid model name: Model names must start with `models/` or `tunedModels/`. Received: {name}" + ) + + +def get_base_model( + name: model_types.BaseModelNameOptions, + *, + client=None, + request_options: helper_types.RequestOptionsType | None = None, +) -> model_types.Model: + """Calls the API to fetch a base model by name. + + ``` + import pprint + model = genai.get_base_model('models/chat-bison-001') + pprint.pprint(model) + ``` + + Args: + name: The name of the model to fetch. Should start with `models/` + client: The client to use. + request_options: Options for the request. + + Returns: + A `types.Model`. + """ + if request_options is None: + request_options = {} + + if client is None: + client = get_default_model_client() + + name = model_types.make_model_name(name) + if not name.startswith("models/"): + raise ValueError( + f"Invalid model name: Base model names must start with `models/`. Received: {name}" + ) + + result = client.get_model(name=name, **request_options) + result = type(result).to_dict(result) + return model_types.Model(**result) + + +def get_tuned_model( + name: model_types.TunedModelNameOptions, + *, + client=None, + request_options: helper_types.RequestOptionsType | None = None, +) -> model_types.TunedModel: + """Calls the API to fetch a tuned model by name. + + ``` + import pprint + model = genai.get_tuned_model('tunedModels/gemini-1.5-flash') + pprint.pprint(model) + ``` + + Args: + name: The name of the model to fetch. Should start with `tunedModels/` + client: The client to use. + request_options: Options for the request. + + Returns: + A `types.TunedModel`. + """ + if request_options is None: + request_options = {} + + if client is None: + client = get_default_model_client() + + name = model_types.make_model_name(name) + + if not name.startswith("tunedModels/"): + raise ValueError( + f"Invalid model name: Tuned model names must start with `tunedModels/`. Received: {name}" + ) + + result = client.get_tuned_model(name=name, **request_options) + + return model_types.decode_tuned_model(result) + + +def get_base_model_name( + model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None +): + """Calls the API to fetch the base model name of a model.""" + + if isinstance(model, str): + if model.startswith("tunedModels/"): + model = get_model(model, client=client) + base_model = model.base_model + else: + base_model = model + elif isinstance(model, model_types.TunedModel): + base_model = model.base_model + elif isinstance(model, model_types.Model): + base_model = model.name + elif isinstance(model, protos.Model): + base_model = model.name + elif isinstance(model, protos.TunedModel): + base_model = getattr(model, "base_model", None) + if not base_model: + base_model = model.tuned_model_source.base_model + else: + raise TypeError( + f"Invalid model: The provided model '{model}' is not recognized or supported. " + "Supported types are: str, model_types.TunedModel, model_types.Model, protos.Model, and protos.TunedModel." + ) + + return base_model + + +def list_models( + *, + page_size: int | None = 50, + client: glm.ModelServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> model_types.ModelsIterable: + """Calls the API to list all available models. + + ``` + import pprint + for model in genai.list_models(): + pprint.pprint(model) + ``` + + Args: + page_size: How many `types.Models` to fetch per page (api call). + client: You may pass a `glm.ModelServiceClient` instead of using the default client. + request_options: Options for the request. + + Yields: + `types.Model` objects. + + """ + if request_options is None: + request_options = {} + + if client is None: + client = get_default_model_client() + + for model in client.list_models(page_size=page_size, **request_options): + model = type(model).to_dict(model) + yield model_types.Model(**model) + + +def list_tuned_models( + *, + page_size: int | None = 50, + client: glm.ModelServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> model_types.TunedModelsIterable: + """Calls the API to list all tuned models. + + ``` + import pprint + for model in genai.list_tuned_models(): + pprint.pprint(model) + ``` + + Args: + page_size: How many `types.Models` to fetch per page (api call). + client: You may pass a `glm.ModelServiceClient` instead of using the default client. + request_options: Options for the request. + + Yields: + `types.TunedModel` objects. + """ + if request_options is None: + request_options = {} + + if client is None: + client = get_default_model_client() + + for model in client.list_tuned_models( + page_size=page_size, + **request_options, + ): + model = type(model).to_dict(model) + yield model_types.decode_tuned_model(model) + + +def create_tuned_model( + source_model: model_types.AnyModelNameOptions, + training_data: model_types.TuningDataOptions, + *, + id: str | None = None, + display_name: str | None = None, + description: str | None = None, + temperature: float | None = None, + top_p: float | None = None, + top_k: int | None = None, + epoch_count: int | None = None, + batch_size: int | None = None, + learning_rate: float | None = None, + input_key: str = "text_input", + output_key: str = "output", + client: glm.ModelServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> operations.CreateTunedModelOperation: + """Calls the API to initiate a tuning process that optimizes a model for specific data, returning an operation object to track and manage the tuning progress. + + Since tuning a model can take significant time, this API doesn't wait for the tuning to complete. + Instead, it returns a `google.api_core.operation.Operation` object that lets you check on the + status of the tuning job, or wait for it to complete, and check the result. + + After the job completes you can either find the resulting `TunedModel` object in + `Operation.result()` or `palm.list_tuned_models` or `palm.get_tuned_model(model_id)`. + + ``` + my_id = "my-tuned-model-id" + operation = palm.create_tuned_model( + id = my_id, + source_model="models/text-bison-001", + training_data=[{'text_input': 'example input', 'output': 'example output'},...] + ) + tuned_model=operation.result() # Wait for tuning to finish + + palm.generate_text(f"tunedModels/{my_id}", prompt="...") + ``` + + Args: + source_model: The name of the model to tune. + training_data: The dataset to tune the model on. This must be either: + * A `protos.Dataset`, or + * An `Iterable` of: + *`protos.TuningExample`, + * `{'text_input': text_input, 'output': output}` dicts + * `(text_input, output)` tuples. + * A `Mapping` of `Iterable[str]` - use `input_key` and `output_key` to choose which + columns to use as the input/output + * A csv file (will be read with `pd.read_csv` and handles as a `Mapping` + above). This can be: + * A local path as a `str` or `pathlib.Path`. + * A url for a csv file. + * The url of a Google Sheets file. + * A JSON file - Its contents will be handled either as an `Iterable` or `Mapping` + above. This can be: + * A local path as a `str` or `pathlib.Path`. + id: The model identifier, used to refer to the model in the API + `tunedModels/{id}`. Must be unique. + display_name: A human-readable name for display. + description: A description of the tuned model. + temperature: The default temperature for the tuned model, see `types.Model` for details. + top_p: The default `top_p` for the model, see `types.Model` for details. + top_k: The default `top_k` for the model, see `types.Model` for details. + epoch_count: The number of tuning epochs to run. An epoch is a pass over the whole dataset. + batch_size: The number of examples to use in each training batch. + learning_rate: The step size multiplier for the gradient updates. + client: Which client to use. + request_options: Options for the request. + + Returns: + A [`google.api_core.operation.Operation`](https://googleapis.dev/python/google-api-core/latest/operation.html) + """ + if request_options is None: + request_options = {} + + if client is None: + client = get_default_model_client() + + source_model_name = model_types.make_model_name(source_model) + base_model_name = get_base_model_name(source_model) + if source_model_name.startswith("models/"): + source_model = {"base_model": source_model_name} + elif source_model_name.startswith("tunedModels/"): + source_model = { + "tuned_model_source": { + "tuned_model": source_model_name, + "base_model": base_model_name, + } + } + else: + raise ValueError( + f"Invalid model name: The provided model '{source_model}' does not match any known model patterns such as 'models/' or 'tunedModels/'" + ) + + training_data = model_types.encode_tuning_data( + training_data, input_key=input_key, output_key=output_key + ) + + hyperparameters = protos.Hyperparameters( + epoch_count=epoch_count, + batch_size=batch_size, + learning_rate=learning_rate, + ) + tuning_task = protos.TuningTask( + training_data=training_data, + hyperparameters=hyperparameters, + ) + + tuned_model = protos.TunedModel( + **source_model, + display_name=display_name, + description=description, + temperature=temperature, + top_p=top_p, + top_k=top_k, + tuning_task=tuning_task, + ) + + operation = client.create_tuned_model( + dict(tuned_model_id=id, tuned_model=tuned_model), **request_options + ) + + return operations.CreateTunedModelOperation.from_core_operation(operation) + + +@typing.overload +def update_tuned_model( + tuned_model: protos.TunedModel, + updates: None = None, + *, + client: glm.ModelServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> model_types.TunedModel: + pass + + +@typing.overload +def update_tuned_model( + tuned_model: str, + updates: dict[str, Any], + *, + client: glm.ModelServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> model_types.TunedModel: + pass + + +def update_tuned_model( + tuned_model: str | protos.TunedModel, + updates: dict[str, Any] | None = None, + *, + client: glm.ModelServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> model_types.TunedModel: + """Calls the API to push updates to a specified tuned model where only certain attributes are updatable.""" + + if request_options is None: + request_options = {} + + if client is None: + client = get_default_model_client() + + if isinstance(tuned_model, str): + name = tuned_model + if not isinstance(updates, dict): + raise TypeError( + f"Invalid argument type: In the function `update_tuned_model(name:str, updates: dict)`, the `updates` argument must be of type `dict`. Received type: {type(updates).__name__}." + ) + + tuned_model = client.get_tuned_model(name=name, **request_options) + + updates = flatten_update_paths(updates) + field_mask = field_mask_pb2.FieldMask() + for path in updates.keys(): + field_mask.paths.append(path) + for path, value in updates.items(): + _apply_update(tuned_model, path, value) + elif isinstance(tuned_model, protos.TunedModel): + if updates is not None: + raise ValueError( + "Invalid argument: When calling `update_tuned_model(tuned_model:protos.TunedModel, updates=None)`, " + "the `updates` argument must not be set." + ) + + name = tuned_model.name + was = client.get_tuned_model(name=name) + field_mask = protobuf_helpers.field_mask(was._pb, tuned_model._pb) + else: + raise TypeError( + "Invalid argument type: In the function `update_tuned_model(tuned_model:dict|protos.TunedModel)`, the " + f"`tuned_model` argument must be of type `dict` or `protos.TunedModel`. Received type: {type(tuned_model).__name__}." + ) + + result = client.update_tuned_model( + protos.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask), + **request_options, + ) + return model_types.decode_tuned_model(result) + + +def _apply_update(thing, path, value): + parts = path.split(".") + for part in parts[:-1]: + thing = getattr(thing, part) + setattr(thing, parts[-1], value) + + +def delete_tuned_model( + tuned_model: model_types.TunedModelNameOptions, + client: glm.ModelServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> None: + """Calls the API to delete a specified tuned model""" + + if request_options is None: + request_options = {} + + if client is None: + client = get_default_model_client() + + name = model_types.make_model_name(tuned_model) + client.delete_tuned_model(name=name, **request_options) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__init__.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16ccb2fdfb395489194884bf086fe999c9b232d9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__init__.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Notebook extensions for Generative AI.""" + + +def load_ipython_extension(ipython): + """Register the Colab Magic extension to support %load_ext.""" + # pylint: disable-next=g-import-not-at-top + from google.generativeai.notebook import magics + + ipython.register_magics(magics.Magics) + + # Since we're in an interactive environment, make the tables prettier. + try: + # pylint: disable-next=g-import-not-at-top + from google import colab # type: ignore + + colab.data_table.enable_dataframe_formatter() + except ImportError: + pass diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/command.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/command.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bef5685c5e1641c8d9aa21f3c4a2b80ff338ed90 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/command.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/compare_cmd.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/compare_cmd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f5fab844dc7b82292a2d7ff7548883ccc22a36f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/compare_cmd.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/eval_cmd.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/eval_cmd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..264bc8d863d285119213134b805ae469601c9ee7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/eval_cmd.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/flag_def.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/flag_def.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94972fd3be2c6c3a1a581ca094339234dce31a6a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/flag_def.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/input_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/input_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca72c516637f09493095dfb5d5ce488f69826a24 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/input_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/magics_engine.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/magics_engine.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43e1d2508f43d1dd4d8546d9cf572f5589329826 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/magics_engine.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/parsed_args_lib.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/parsed_args_lib.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2e783d22b4209e2867c1fa6f4675256cd76f8ef Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/parsed_args_lib.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/run_cmd.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/run_cmd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2f10e2cbec071698cc14f8dcfcbd4c85f4ee271 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/run_cmd.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_id.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_id.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..637e489a43c9f2a70396d27059f3a58d3cbe8319 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_id.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88720212abd4273fc24132b1b0ba969149a0b8ca Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/cmd_line_parser.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/cmd_line_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9b8e84048a4e2ecd564f3f1ace560f6c802a6d83 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/cmd_line_parser.py @@ -0,0 +1,557 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parses an LLM command line.""" +from __future__ import annotations + +import argparse +import shlex +import sys +from typing import AbstractSet, Any, Callable, MutableMapping, Sequence + +from google.generativeai.notebook import argument_parser +from google.generativeai.notebook import flag_def +from google.generativeai.notebook import input_utils +from google.generativeai.notebook import model_registry +from google.generativeai.notebook import output_utils +from google.generativeai.notebook import parsed_args_lib +from google.generativeai.notebook import post_process_utils +from google.generativeai.notebook import py_utils +from google.generativeai.notebook import sheets_utils +from google.generativeai.notebook.lib import llm_function +from google.generativeai.notebook.lib import llmfn_inputs_source +from google.generativeai.notebook.lib import llmfn_outputs +from google.generativeai.notebook.lib import model as model_lib + + +_MIN_CANDIDATE_COUNT = 1 +_MAX_CANDIDATE_COUNT = 8 + + +def _validate_input_source_against_placeholders( + source: llmfn_inputs_source.LLMFnInputsSource, + placeholders: AbstractSet[str], +) -> None: + for inputs in source.to_normalized_inputs(): + for keyword in placeholders: + if keyword not in inputs: + raise ValueError('Placeholder "{}" not found in input'.format(keyword)) + + +def _get_resolve_input_from_py_var_fn( + placeholders: AbstractSet[str] | None, +) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]: + def _fn(var_name: str) -> llmfn_inputs_source.LLMFnInputsSource: + source = input_utils.get_inputs_source_from_py_var(var_name) + if placeholders: + _validate_input_source_against_placeholders(source, placeholders) + return source + + return _fn + + +def _resolve_compare_fn_var( + name: str, +) -> tuple[str, parsed_args_lib.TextResultCompareFn]: + """Resolves a value passed into --compare_fn.""" + fn = py_utils.get_py_var(name) + if not isinstance(fn, Callable): + raise ValueError('Variable "{}" does not contain a Callable object'.format(name)) + + return name, fn + + +def _resolve_ground_truth_var(name: str) -> Sequence[str]: + """Resolves a value passed into --ground_truth.""" + value = py_utils.get_py_var(name) + + # "str" and "bytes" are also Sequences but we want an actual Sequence of + # strings, like a list. + if not isinstance(value, Sequence) or isinstance(value, str) or isinstance(value, bytes): + raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name)) + for x in value: + if not isinstance(x, str): + raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name)) + return value + + +def _get_resolve_sheets_inputs_fn( + placeholders: AbstractSet[str] | None, +) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]: + def _fn(value: str) -> llmfn_inputs_source.LLMFnInputsSource: + sheets_id = sheets_utils.get_sheets_id_from_str(value) + source = sheets_utils.SheetsInputs(sheets_id) + if placeholders: + _validate_input_source_against_placeholders(source, placeholders) + return source + + return _fn + + +def _resolve_sheets_outputs(value: str) -> llmfn_outputs.LLMFnOutputsSink: + sheets_id = sheets_utils.get_sheets_id_from_str(value) + return sheets_utils.SheetsOutputs(sheets_id) + + +def _add_model_flags( + parser: argparse.ArgumentParser, +) -> None: + """Adds flags that are related to model selection and config.""" + flag_def.EnumFlagDef( + name="model_type", + short_name="mt", + enum_type=model_registry.ModelName, + default_value=model_registry.ModelRegistry.DEFAULT_MODEL, + help_msg="The type of model to use.", + ).add_argument_to_parser(parser) + + def _check_is_greater_than_or_equal_to_zero(x: float) -> float: + if x < 0: + raise ValueError("Value should be greater than or equal to zero, got {}".format(x)) + return x + + flag_def.SingleValueFlagDef( + name="temperature", + short_name="t", + parse_type=float, + # Use None for default value to indicate that this will use the default + # value in Text service. + default_value=None, + parse_to_dest_type_fn=_check_is_greater_than_or_equal_to_zero, + help_msg=( + "Controls the randomness of the output. Must be positive. Typical" + " values are in the range: [0.0, 1.0]. Higher values produce a more" + " random and varied response. A temperature of zero will be" + " deterministic." + ), + ).add_argument_to_parser(parser) + + flag_def.SingleValueFlagDef( + name="model", + short_name="m", + default_value=None, + help_msg=( + "The name of the model to use. If not provided, a default model will" " be used." + ), + ).add_argument_to_parser(parser) + + def _check_candidate_count_range(x: Any) -> int: + if x < _MIN_CANDIDATE_COUNT or x > _MAX_CANDIDATE_COUNT: + raise ValueError( + "Value should be in the range [{}, {}], got {}".format( + _MIN_CANDIDATE_COUNT, _MAX_CANDIDATE_COUNT, x + ) + ) + return int(x) + + flag_def.SingleValueFlagDef( + name="candidate_count", + short_name="cc", + parse_type=int, + # Use None for default value to indicate that this will use the default + # value in Text service. + default_value=None, + parse_to_dest_type_fn=_check_candidate_count_range, + help_msg="The number of candidates to produce.", + ).add_argument_to_parser(parser) + + flag_def.BooleanFlagDef( + name="unique", + help_msg="Whether to dedupe candidates returned by the model.", + ).add_argument_to_parser(parser) + + +def _add_input_flags( + parser: argparse.ArgumentParser, + placeholders: AbstractSet[str] | None, +) -> None: + """Adds flags to read inputs from a Python variable or Sheets.""" + flag_def.MultiValuesFlagDef( + name="inputs", + short_name="i", + dest_type=llmfn_inputs_source.LLMFnInputsSource, + parse_to_dest_type_fn=_get_resolve_input_from_py_var_fn(placeholders), + help_msg=( + "Optional names of Python variables containing inputs to use to" + " instantiate a prompt. The variable must be either: a dictionary" + " {'key1': ['val1', 'val2'] ...}, or an instance of LLMFnInputsSource" + " such as SheetsInput." + ), + ).add_argument_to_parser(parser) + + flag_def.MultiValuesFlagDef( + name="sheets_input_names", + short_name="si", + dest_type=llmfn_inputs_source.LLMFnInputsSource, + parse_to_dest_type_fn=_get_resolve_sheets_inputs_fn(placeholders), + help_msg=( + "Optional names of Google Sheets to read inputs from. This is" + " equivalent to using --inputs with the names of variables that are" + " instances of SheetsInputs, just more convenient to use." + ), + ).add_argument_to_parser(parser) + + +def _add_output_flags( + parser: argparse.ArgumentParser, +) -> None: + """Adds flags to write outputs to a Python variable.""" + flag_def.MultiValuesFlagDef( + name="outputs", + short_name="o", + dest_type=llmfn_outputs.LLMFnOutputsSink, + parse_to_dest_type_fn=output_utils.get_outputs_sink_from_py_var, + help_msg=( + "Optional names of Python variables to output to. If the Python" + " variable has not already been defined, it will be created. If the" + " variable is defined and is an instance of LLMFnOutputsSink, the" + " outputs will be written through the sink's write_outputs() method." + ), + ).add_argument_to_parser(parser) + + flag_def.MultiValuesFlagDef( + name="sheets_output_names", + short_name="so", + dest_type=llmfn_outputs.LLMFnOutputsSink, + parse_to_dest_type_fn=_resolve_sheets_outputs, + help_msg=( + "Optional names of Google Sheets to write inputs to. This is" + " equivalent to using --outputs with the names of variables that are" + " instances of SheetsOutputs, just more convenient to use." + ), + ).add_argument_to_parser(parser) + + +def _add_compare_flags( + parser: argparse.ArgumentParser, +) -> None: + flag_def.MultiValuesFlagDef( + name="compare_fn", + dest_type=tuple, + parse_to_dest_type_fn=_resolve_compare_fn_var, + help_msg=( + "An optional function that takes two inputs: (lhs_result, rhs_result)" + " which are the results of the left- and right-hand side functions. " + "Multiple comparison functions can be provided." + ), + ).add_argument_to_parser(parser) + + +def _add_eval_flags( + parser: argparse.ArgumentParser, +) -> None: + flag_def.SingleValueFlagDef( + name="ground_truth", + required=True, + dest_type=Sequence, + parse_to_dest_type_fn=_resolve_ground_truth_var, + help_msg=( + "A variable containing a Sequence of strings representing the ground" + " truth that the output of this cell will be compared against. It" + " should have the same number of entries as inputs." + ), + ).add_argument_to_parser(parser) + + +def _create_run_parser( + parser: argparse.ArgumentParser, + placeholders: AbstractSet[str] | None, +) -> None: + """Adds flags for the `run` command. + + `run` sends one or more prompts to a model. + + Args: + parser: The parser to which flags will be added. + placeholders: Placeholders from prompts in the cell contents. + """ + _add_model_flags(parser) + _add_input_flags(parser, placeholders) + _add_output_flags(parser) + + +def _create_compile_parser( + parser: argparse.ArgumentParser, +) -> None: + """Adds flags for the compile command. + + `compile` "compiles" a prompt and model call into a callable function. + + Args: + parser: The parser to which flags will be added. + """ + + # Add a positional argument for "compile_save_name". + def _compile_save_name_fn(var_name: str) -> str: + try: + py_utils.validate_var_name(var_name) + except ValueError as e: + # Re-raise as ArgumentError to preserve the original error message. + raise argparse.ArgumentError(None, "{}".format(e)) from e + return var_name + + save_name_help = "The name of a Python variable to save the compiled function to." + parser.add_argument("compile_save_name", help=save_name_help, type=_compile_save_name_fn) + _add_model_flags(parser) + + +def _create_compare_parser( + parser: argparse.ArgumentParser, + placeholders: AbstractSet[str] | None, +) -> None: + """Adds flags for the compare command. + + Args: + parser: The parser to which flags will be added. + placeholders: Placeholders from prompts in the compiled functions. + """ + + # Add positional arguments. + def _resolve_llm_function_fn( + var_name: str, + ) -> tuple[str, llm_function.LLMFunction]: + try: + py_utils.validate_var_name(var_name) + except ValueError as e: + # Re-raise as ArgumentError to preserve the original error message. + raise argparse.ArgumentError(None, "{}".format(e)) from e + + fn = py_utils.get_py_var(var_name) + if not isinstance(fn, llm_function.LLMFunction): + raise argparse.ArgumentError( + None, + '{} is not a function created with the "compile" command'.format(var_name), + ) + return var_name, fn + + name_help = ( + "The name of a Python variable containing a function previously created" + ' with the "compile" command.' + ) + parser.add_argument("lhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn) + parser.add_argument("rhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn) + + _add_input_flags(parser, placeholders) + _add_output_flags(parser) + _add_compare_flags(parser) + + +def _create_eval_parser( + parser: argparse.ArgumentParser, + placeholders: AbstractSet[str] | None, +) -> None: + """Adds flags for the eval command. + + Args: + parser: The parser to which flags will be added. + placeholders: Placeholders from prompts in the cell contents. + """ + _add_model_flags(parser) + _add_input_flags(parser, placeholders) + _add_output_flags(parser) + _add_compare_flags(parser) + _add_eval_flags(parser) + + +def _create_parser( + placeholders: AbstractSet[str] | None, +) -> argparse.ArgumentParser: + """Create the full parser.""" + system_name = "llm" + description = "A system for interacting with LLMs." + epilog = "" + + # Commands + parser = argument_parser.ArgumentParser( + prog=system_name, + description=description, + epilog=epilog, + ) + subparsers = parser.add_subparsers(dest="cmd") + _create_run_parser( + subparsers.add_parser(parsed_args_lib.CommandName.RUN_CMD.value), + placeholders, + ) + _create_compile_parser(subparsers.add_parser(parsed_args_lib.CommandName.COMPILE_CMD.value)) + _create_compare_parser( + subparsers.add_parser(parsed_args_lib.CommandName.COMPARE_CMD.value), + placeholders, + ) + _create_eval_parser( + subparsers.add_parser(parsed_args_lib.CommandName.EVAL_CMD.value), + placeholders, + ) + return parser + + +def _validate_parsed_args(parsed_args: parsed_args_lib.ParsedArgs) -> None: + # If candidate_count is not set (i.e. is None), assuming the default value + # is 1. + if parsed_args.unique and ( + parsed_args.model_args.candidate_count is None + or parsed_args.model_args.candidate_count == 1 + ): + print( + '"--unique" works across candidates only: it should be used with' + " --candidate_count set to a value greater-than one." + ) + + +class CmdLineParser: + """Implementation of Magics command line parser.""" + + # Commands + DEFAULT_CMD = parsed_args_lib.CommandName.RUN_CMD + + # Post-processing operator. + PIPE_OP = "|" + + @classmethod + def _split_post_processing_tokens( + cls, + tokens: Sequence[str], + ) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]: + """Splits inputs into the command and post processing tokens. + + The command is represented as a sequence of tokens. + See comments on the PostProcessingTokens type alias. + + E.g. Given: "run --temperature 0.5 | add_score | to_lower_case" + The command will be: ["run", "--temperature", "0.5"]. + The post processing tokens will be: [["add_score"], ["to_lower_case"]] + + Args: + tokens: The command line tokens. + + Returns: + A tuple of (command line, post processing tokens). + """ + split_tokens = [] + start_idx: int | None = None + for token_num, token in enumerate(tokens): + if start_idx is None: + start_idx = token_num + if token == CmdLineParser.PIPE_OP: + split_tokens.append(tokens[start_idx:token_num] if start_idx is not None else []) + start_idx = None + + # Add the remaining tokens after the last PIPE_OP. + split_tokens.append(tokens[start_idx:] if start_idx is not None else []) + + return split_tokens[0], split_tokens[1:] + + @classmethod + def _tokenize_line( + cls, line: str + ) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]: + """Parses `line` and returns command line and post processing tokens.""" + # Check to make sure there is a command at the start. If not, add the + # default command to the list of tokens. + tokens = shlex.split(line) + if not tokens: + tokens = [CmdLineParser.DEFAULT_CMD.value] + first_token = tokens[0] + # Add default command if the first token is not the help token. + if not first_token[0].isalpha() and first_token not in ["-h", "--help"]: + tokens = [CmdLineParser.DEFAULT_CMD.value] + tokens + # Split line into tokens and post-processing + return CmdLineParser._split_post_processing_tokens(tokens) + + @classmethod + def _get_model_args( + cls, parsed_results: MutableMapping[str, Any] + ) -> tuple[MutableMapping[str, Any], model_lib.ModelArguments]: + """Extracts fields for model args from `parsed_results`. + + Keys specific to model arguments will be removed from `parsed_results`. + + Args: + parsed_results: A dictionary of parsed arguments (from ArgumentParser). It + will be modified in place. + + Returns: + A tuple of (updated parsed_results, model arguments). + """ + model = parsed_results.pop("model", None) + temperature = parsed_results.pop("temperature", None) + candidate_count = parsed_results.pop("candidate_count", None) + + model_args = model_lib.ModelArguments( + model=model, + temperature=temperature, + candidate_count=candidate_count, + ) + return parsed_results, model_args + + def parse_line( + self, + line: str, + placeholders: AbstractSet[str] | None = None, + ) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]: + """Parses the commandline and returns ParsedArgs and post-processing tokens. + + Args: + line: The line to parse (usually contents from cell Magics). + placeholders: Placeholders from prompts in the cell contents. + + Returns: + A tuple of (parsed_args, post_processing_tokens). + """ + tokens, post_processing_tokens = CmdLineParser._tokenize_line(line) + + parsed_args = self._get_parsed_args_from_cmd_line_tokens( + tokens=tokens, placeholders=placeholders + ) + + # Special-case for "compare" command: because the prompts are compiled into + # the left- and right-hand side functions rather than in the cell body, we + # cannot examine the cell body to get the placeholders. + # + # Instead we parse the command line twice: once to get the left- and right- + # functions, then we query the functions for their placeholders, then + # parse the commandline again to validate the inputs. + if parsed_args.cmd == parsed_args_lib.CommandName.COMPARE_CMD: + assert parsed_args.lhs_name_and_fn is not None + assert parsed_args.rhs_name_and_fn is not None + _, lhs_fn = parsed_args.lhs_name_and_fn + _, rhs_fn = parsed_args.rhs_name_and_fn + parsed_args = self._get_parsed_args_from_cmd_line_tokens( + tokens=tokens, + placeholders=frozenset(lhs_fn.get_placeholders()).union(rhs_fn.get_placeholders()), + ) + + _validate_parsed_args(parsed_args) + + for expr in post_processing_tokens: + post_process_utils.validate_one_post_processing_expression(expr) + + return parsed_args, post_processing_tokens + + def _get_parsed_args_from_cmd_line_tokens( + self, + tokens: Sequence[str], + placeholders: AbstractSet[str] | None, + ) -> parsed_args_lib.ParsedArgs: + """Returns ParsedArgs from a tokenized command line.""" + # Create a new parser to avoid reusing the temporary argparse.Namespace + # object. + results = _create_parser(placeholders).parse_args(tokens) + + results_dict = vars(results) + results_dict["cmd"] = parsed_args_lib.CommandName(results_dict["cmd"]) + + results_dict, model_args = CmdLineParser._get_model_args(results_dict) + results_dict["model_args"] = model_args + + return parsed_args_lib.ParsedArgs(**results_dict) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/command.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/command.py new file mode 100644 index 0000000000000000000000000000000000000000..df261326260d06cfdde896d657330834147eb79b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/command.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Command.""" +from __future__ import annotations + +import abc +import collections +from typing import Sequence + +from google.generativeai.notebook import parsed_args_lib +from google.generativeai.notebook import post_process_utils + + +ProcessingCommand = collections.namedtuple("ProcessingCommand", ["name", "fn"]) + + +class Command(abc.ABC): + """Base class for implementation of Magics commands like "run".""" + + @abc.abstractmethod + def execute( + self, + parsed_args: parsed_args_lib.ParsedArgs, + cell_content: str, + post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr], + ): + """Executes the command given `parsed_args` and the `cell_content`.""" + + @abc.abstractmethod + def parse_post_processing_tokens( + self, tokens: Sequence[Sequence[str]] + ) -> Sequence[post_process_utils.ParsedPostProcessExpr]: + """Parses post-processing tokens for this command.""" diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/compare_cmd.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/compare_cmd.py new file mode 100644 index 0000000000000000000000000000000000000000..bea61e17f23ef77312c77d0dd0462b84971510ad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/compare_cmd.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The compare command.""" +from __future__ import annotations + +from typing import Sequence + +from google.generativeai.notebook import command +from google.generativeai.notebook import command_utils +from google.generativeai.notebook import input_utils +from google.generativeai.notebook import ipython_env +from google.generativeai.notebook import output_utils +from google.generativeai.notebook import parsed_args_lib +from google.generativeai.notebook import post_process_utils +import pandas + + +class CompareCommand(command.Command): + """Implementation of "compare" command.""" + + def __init__( + self, + env: ipython_env.IPythonEnv | None = None, + ): + """Constructor. + + Args: + env: The IPythonEnv environment. + """ + super().__init__() + self._ipython_env = env + + def execute( + self, + parsed_args: parsed_args_lib.ParsedArgs, + cell_content: str, + post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr], + ) -> pandas.DataFrame: + # We expect CmdLineParser to have already read the inputs once to validate + # that the placeholders in the prompt are present in the inputs, so we can + # suppress the status messages here. + inputs = input_utils.join_inputs_sources(parsed_args, suppress_status_msgs=True) + + llm_cmp_fn = command_utils.create_llm_compare_function( + env=self._ipython_env, + parsed_args=parsed_args, + post_processing_fns=post_processing_fns, + ) + + results = llm_cmp_fn(inputs=inputs) + output_utils.write_to_outputs(results=results, parsed_args=parsed_args) + return results.as_pandas_dataframe() + + def parse_post_processing_tokens( + self, tokens: Sequence[Sequence[str]] + ) -> Sequence[post_process_utils.ParsedPostProcessExpr]: + if tokens: + raise RuntimeError('Post-processing is not supported by "compare"') + return [] diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/compile_cmd.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/compile_cmd.py new file mode 100644 index 0000000000000000000000000000000000000000..e06401f3c670d5e916eb6e9dc43d82c0b94e14d3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/compile_cmd.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The compile command.""" +from __future__ import annotations + +from typing import Sequence +from google.generativeai.notebook import command +from google.generativeai.notebook import command_utils +from google.generativeai.notebook import ipython_env +from google.generativeai.notebook import model_registry +from google.generativeai.notebook import parsed_args_lib +from google.generativeai.notebook import post_process_utils +from google.generativeai.notebook import py_utils + + +class CompileCommand(command.Command): + """Implementation of the "compile" command.""" + + def __init__( + self, + models: model_registry.ModelRegistry, + env: ipython_env.IPythonEnv | None = None, + ): + """Constructor. + + Args: + models: ModelRegistry instance. + env: The IPythonEnv environment. + """ + super().__init__() + self._models = models + self._ipython_env = env + + def execute( + self, + parsed_args: parsed_args_lib.ParsedArgs, + cell_content: str, + post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr], + ) -> str: + llm_fn = command_utils.create_llm_function( + models=self._models, + env=self._ipython_env, + parsed_args=parsed_args, + cell_content=cell_content, + post_processing_fns=post_processing_fns, + ) + + py_utils.set_py_var(parsed_args.compile_save_name, llm_fn) + return "Saved function to Python variable: {}".format(parsed_args.compile_save_name) + + def parse_post_processing_tokens( + self, tokens: Sequence[Sequence[str]] + ) -> Sequence[post_process_utils.ParsedPostProcessExpr]: + return post_process_utils.resolve_post_processing_tokens(tokens) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/eval_cmd.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/eval_cmd.py new file mode 100644 index 0000000000000000000000000000000000000000..3806e8ad4d0368971409adea7bf75f9ae10c14e9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/eval_cmd.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The eval command.""" +from __future__ import annotations + +from typing import Sequence + +from google.generativeai.notebook import command +from google.generativeai.notebook import command_utils +from google.generativeai.notebook import input_utils +from google.generativeai.notebook import ipython_env +from google.generativeai.notebook import model_registry +from google.generativeai.notebook import output_utils +from google.generativeai.notebook import parsed_args_lib +from google.generativeai.notebook import post_process_utils +import pandas + + +class EvalCommand(command.Command): + """Implementation of "eval" command.""" + + def __init__( + self, + models: model_registry.ModelRegistry, + env: ipython_env.IPythonEnv | None = None, + ): + """Constructor. + + Args: + models: ModelRegistry instance. + env: The IPythonEnv environment. + """ + super().__init__() + self._models = models + self._ipython_env = env + + def execute( + self, + parsed_args: parsed_args_lib.ParsedArgs, + cell_content: str, + post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr], + ) -> pandas.DataFrame: + # We expect CmdLineParser to have already read the inputs once to validate + # that the placeholders in the prompt are present in the inputs, so we can + # suppress the status messages here. + inputs = input_utils.join_inputs_sources(parsed_args, suppress_status_msgs=True) + + llm_cmp_fn = command_utils.create_llm_eval_function( + models=self._models, + env=self._ipython_env, + parsed_args=parsed_args, + cell_content=cell_content, + post_processing_fns=post_processing_fns, + ) + + results = llm_cmp_fn(inputs=inputs) + output_utils.write_to_outputs(results=results, parsed_args=parsed_args) + return results.as_pandas_dataframe() + + def parse_post_processing_tokens( + self, tokens: Sequence[Sequence[str]] + ) -> Sequence[post_process_utils.ParsedPostProcessExpr]: + return post_process_utils.resolve_post_processing_tokens(tokens) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/flag_def.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/flag_def.py new file mode 100644 index 0000000000000000000000000000000000000000..469dee711a71d83c00a7a85aa50a5129a0be1f5e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/flag_def.py @@ -0,0 +1,503 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classes that define arguments for populating ArgumentParser. + +The argparse module's ArgumentParser.add_argument() takes several parameters and +is quite customizable. However this can lead to bugs where arguments do not +behave as expected. + +For better ease-of-use and better testability, define a set of classes for the +types of flags used by LLM Magics. + +Sample usage: + + str_flag = SingleValueFlagDef(name="title", required=True) + enum_flag = EnumFlagDef(name="colors", required=True, enum_type=ColorsEnum) + + str_flag.add_argument_to_parser(my_parser) + enum_flag.add_argument_to_parser(my_parser) +""" +from __future__ import annotations + +import abc +import argparse +import dataclasses +import enum +from typing import Any, Callable, Sequence, Tuple, Union + +from google.generativeai.notebook.lib import llmfn_inputs_source +from google.generativeai.notebook.lib import llmfn_outputs + +# These are the intermediate types that argparse.ArgumentParser.parse_args() +# will pass command line arguments into. +_PARSETYPES = Union[str, int, float] +# These are the final result types that the intermediate parsed values will be +# converted into. It is a superset of _PARSETYPES because we support converting +# the parsed type into a more precise type, e.g. from str to Enum. +_DESTTYPES = Union[ + _PARSETYPES, + enum.Enum, + Tuple[str, Callable[[str, str], Any]], + Sequence[str], # For --compare_fn + llmfn_inputs_source.LLMFnInputsSource, # For --ground_truth + llmfn_outputs.LLMFnOutputsSink, # For --inputs # For --outputs +] + +# The signature of a function that converts a command line argument from the +# intermediate parsed type to the result type. +_PARSEFN = Callable[[_PARSETYPES], _DESTTYPES] + + +def _get_type_name(x: type[Any]) -> str: + try: + return x.__name__ + except AttributeError: + return str(x) + + +def _validate_flag_name(name: str) -> str: + """Validation for long and short names for flags.""" + if not name: + raise ValueError("Cannot be empty") + if name[0] == "-": + raise ValueError("Cannot start with dash") + return name + + +@dataclasses.dataclass(frozen=True) +class FlagDef(abc.ABC): + """Abstract base class for flag definitions. + + Attributes: + name: Long name, e.g. "colors" will define the flag "--colors". + required: Whether the flag must be provided on the command line. + short_name: Optional short name. + parse_type: The type that ArgumentParser should parse the command line + argument to. + dest_type: The type that the parsed value is converted to. This is used when + we want ArgumentParser to parse as one type, then convert to a different + type. E.g. for enums we parse as "str" then convert to the desired enum + type in order to provide cleaner help messages. + parse_to_dest_type_fn: If provided, this function will be used to convert + the value from `parse_type` to `dest_type`. This can be used for + validation as well. + choices: If provided, limit the set of acceptable values to these choices. + help_msg: If provided, adds help message when -h is used in the command + line. + """ + + name: str + required: bool = False + + short_name: str | None = None + + parse_type: type[_PARSETYPES] = str + dest_type: type[_DESTTYPES] | None = None + parse_to_dest_type_fn: _PARSEFN | None = None + + choices: list[_PARSETYPES] | None = None + help_msg: str | None = None + + @abc.abstractmethod + def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None: + """Adds this flag as an argument to `parser`. + + Child classes should implement this as a call to parser.add_argument() + with the appropriate parameters. + + Args: + parser: The parser to which this argument will be added. + """ + + @abc.abstractmethod + def _do_additional_validation(self) -> None: + """For child classes to do additional validation.""" + + def _get_dest_type(self) -> type[_DESTTYPES]: + """Returns the final converted type.""" + return self.parse_type if self.dest_type is None else self.dest_type + + def _get_parse_to_dest_type_fn( + self, + ) -> _PARSEFN: + """Returns a function to convert from parse_type to dest_type.""" + if self.parse_to_dest_type_fn is not None: + return self.parse_to_dest_type_fn + + dest_type = self._get_dest_type() + if dest_type == self.parse_type: + return lambda x: x + else: + return dest_type + + def __post_init__(self): + _validate_flag_name(self.name) + if self.short_name is not None: + _validate_flag_name(self.short_name) + + self._do_additional_validation() + + +def _has_non_default_value( + namespace: argparse.Namespace, + dest: str, + has_default: bool = False, + default_value: Any = None, +) -> bool: + """Returns true if `namespace.dest` is set to a non-default value. + + Args: + namespace: The Namespace that is populated by ArgumentParser. + dest: The attribute in the Namespace to be populated. + has_default: "None" is a valid default value so we use an additional + `has_default` boolean to indicate that `default_value` is present. + default_value: The default value to use when `has_default` is True. + + Returns: + Whether namespace.dest is set to something other than the default value. + """ + if not hasattr(namespace, dest): + return False + + if not has_default: + # No default value provided so `namespace.dest` cannot possibly be equal to + # the default value. + return True + + return getattr(namespace, dest) != default_value + + +class _SingleValueStoreAction(argparse.Action): + """Custom Action for storing a value in an argparse.Namespace. + + This action checks that the flag is specified at-most once. + """ + + def __init__( + self, + option_strings, + dest, + dest_type: type[Any], + parse_to_dest_type_fn: _PARSEFN, + **kwargs, + ): + super().__init__(option_strings, dest, **kwargs) + self._dest_type = dest_type + self._parse_to_dest_type_fn = parse_to_dest_type_fn + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: str | Sequence[Any] | None, + option_string: str | None = None, + ): + # Because `nargs` is set to 1, `values` must be a Sequence, rather + # than a string. + assert not isinstance(values, str) and not isinstance(values, bytes) + + if _has_non_default_value( + namespace, + self.dest, + has_default=hasattr(self, "default"), + default_value=getattr(self, "default"), + ): + raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string)) + + try: + converted_value = self._parse_to_dest_type_fn(values[0]) + except Exception as e: + raise argparse.ArgumentError( + self, + 'Error with value "{}", got {}: {}'.format(values[0], _get_type_name(type(e)), e), + ) + + if not isinstance(converted_value, self._dest_type): + raise RuntimeError( + "Converted to wrong type, expected {} got {}".format( + _get_type_name(self._dest_type), + _get_type_name(type(converted_value)), + ) + ) + setattr(namespace, self.dest, converted_value) + + +class _MultiValuesAppendAction(argparse.Action): + """Custom Action for appending values in an argparse.Namespace. + + This action checks that the flag is specified at-most once. + """ + + def __init__( + self, + option_strings, + dest, + dest_type: type[Any], + parse_to_dest_type_fn: _PARSEFN, + **kwargs, + ): + super().__init__(option_strings, dest, **kwargs) + self._dest_type = dest_type + self._parse_to_dest_type_fn = parse_to_dest_type_fn + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: str | Sequence[Any] | None, + option_string: str | None = None, + ): + # Because `nargs` is set to "+", `values` must be a Sequence, rather + # than a string. + assert not isinstance(values, str) and not isinstance(values, bytes) + + curr_value = getattr(namespace, self.dest) + if curr_value: + raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string)) + + for value in values: + try: + converted_value = self._parse_to_dest_type_fn(value) + except Exception as e: + raise argparse.ArgumentError( + self, + 'Error with value "{}", got {}: {}'.format( + values[0], _get_type_name(type(e)), e + ), + ) + + if not isinstance(converted_value, self._dest_type): + raise RuntimeError( + "Converted to wrong type, expected {} got {}".format( + self._dest_type, type(converted_value) + ) + ) + if converted_value in curr_value: + raise argparse.ArgumentError(self, 'Duplicate values "{}"'.format(value)) + + curr_value.append(converted_value) + + +class _BooleanValueStoreAction(argparse.Action): + """Custom Action for setting a boolean value in argparse.Namespace. + + The boolean flag expects the default to be False and will set the value to + True. + This action checks that the flag is specified at-most once. + """ + + def __init__( + self, + option_strings, + dest, + **kwargs, + ): + super().__init__(option_strings, dest, **kwargs) + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: str | Sequence[Any] | None, + option_string: str | None = None, + ): + if _has_non_default_value( + namespace, + self.dest, + has_default=True, + default_value=False, + ): + raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string)) + + setattr(namespace, self.dest, True) + + +@dataclasses.dataclass(frozen=True) +class SingleValueFlagDef(FlagDef): + """Definition for a flag that takes a single value. + + Sample usage: + # This defines a flag that can be specified on the command line as: + # --count=10 + flag = SingleValueFlagDef(name="count", parse_type=int, required=True) + flag.add_argument_to_parser(argument_parser) + + Attributes: + default_value: Default value for optional flags. + """ + + class _DefaultValue(enum.Enum): + """Special value to represent "no value provided". + + "None" can be used as a default value, so in order to differentiate between + "None" and "no value provided", create a special value for "no value + provided". + """ + + NOT_SET = None + + default_value: _DESTTYPES | _DefaultValue | None = _DefaultValue.NOT_SET + + def _has_default_value(self) -> bool: + """Returns whether `default_value` has been provided.""" + return self.default_value != SingleValueFlagDef._DefaultValue.NOT_SET + + def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None: + args = ["--" + self.name] + if self.short_name is not None: + args += ["-" + self.short_name] + + kwargs = {} + if self._has_default_value(): + kwargs["default"] = self.default_value + if self.choices is not None: + kwargs["choices"] = self.choices + if self.help_msg is not None: + kwargs["help"] = self.help_msg + + parser.add_argument( + *args, + action=_SingleValueStoreAction, + type=self.parse_type, + dest_type=self._get_dest_type(), + parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(), + required=self.required, + nargs=1, + **kwargs, + ) + + def _do_additional_validation(self) -> None: + if self.required: + if self._has_default_value(): + raise ValueError("Required flags cannot have default value") + else: + if not self._has_default_value(): + raise ValueError("Optional flags must have a default value") + + if self._has_default_value() and self.default_value is not None: + if not isinstance(self.default_value, self._get_dest_type()): + raise ValueError("Default value must be of the same type as the destination type") + + +class EnumFlagDef(SingleValueFlagDef): + """Definition for a flag that takes a value from an Enum. + + Sample usage: + # This defines a flag that can be specified on the command line as: + # --color=red + flag = SingleValueFlagDef(name="color", enum_type=ColorsEnum, + required=True) + flag.add_argument_to_parser(argument_parser) + """ + + def __init__(self, *args, enum_type: type[enum.Enum], **kwargs): + if not issubclass(enum_type, enum.Enum): + raise TypeError('"enum_type" must be of type Enum') + + # These properties are set by "enum_type" so don"t let the caller set them. + if "parse_type" in kwargs: + raise ValueError('Cannot set "parse_type" for EnumFlagDef; set "enum_type" instead') + kwargs["parse_type"] = str + + if "dest_type" in kwargs: + raise ValueError('Cannot set "dest_type" for EnumFlagDef; set "enum_type" instead') + kwargs["dest_type"] = enum_type + + if "choices" in kwargs: + # Verify that entries in `choices` are valid enum values. + for x in kwargs["choices"]: + try: + enum_type(x) + except ValueError: + raise ValueError('Invalid value in "choices": "{}"'.format(x)) from None + else: + kwargs["choices"] = [x.value for x in enum_type] + + super().__init__(*args, **kwargs) + + +class MultiValuesFlagDef(FlagDef): + """Definition for a flag that takes multiple values. + + Sample usage: + # This defines a flag that can be specified on the command line as: + # --colors=red green blue + flag = MultiValuesFlagDef(name="colors", parse_type=str, required=True) + flag.add_argument_to_parser(argument_parser) + """ + + def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None: + args = ["--" + self.name] + if self.short_name is not None: + args += ["-" + self.short_name] + + kwargs = {} + if self.choices is not None: + kwargs["choices"] = self.choices + if self.help_msg is not None: + kwargs["help"] = self.help_msg + + parser.add_argument( + *args, + action=_MultiValuesAppendAction, + type=self.parse_type, + dest_type=self._get_dest_type(), + parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(), + required=self.required, + default=[], + nargs="+", + **kwargs, + ) + + def _do_additional_validation(self) -> None: + # No additional validation needed. + pass + + +@dataclasses.dataclass(frozen=True) +class BooleanFlagDef(FlagDef): + """Definition for a Boolean flag. + + A boolean flag is always optional with a default value of False. The flag does + not take any values. Specifying the flag on the commandline will set it to + True. + """ + + def _do_additional_validation(self) -> None: + if self.dest_type is not None: + raise ValueError("dest_type cannot be set for BooleanFlagDef") + if self.parse_to_dest_type_fn is not None: + raise ValueError("parse_to_dest_type_fn cannot be set for BooleanFlagDef") + if self.choices is not None: + raise ValueError("choices cannot be set for BooleanFlagDef") + + def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None: + args = ["--" + self.name] + if self.short_name is not None: + args += ["-" + self.short_name] + + kwargs = {} + if self.help_msg is not None: + kwargs["help"] = self.help_msg + + parser.add_argument( + *args, + action=_BooleanValueStoreAction, + type=bool, + required=False, + default=False, + nargs=0, + **kwargs, + ) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/gspread_client.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/gspread_client.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc42c4531061cf5fb0d0063811b97befa1bea33 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/gspread_client.py @@ -0,0 +1,232 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module that holds a global gspread.client.Client.""" +from __future__ import annotations + +import abc +import datetime +from typing import Any, Callable, Mapping, Sequence +from google.auth import credentials +from google.generativeai.notebook import html_utils +from google.generativeai.notebook import ipython_env +from google.generativeai.notebook import sheets_id + + +# The code may be running in an environment where the gspread library has not +# been installed. +_gspread_import_error: Exception | None = None +try: + # pylint: disable-next=g-import-not-at-top + import gspread +except ImportError as e: + _gspread_import_error = e + gspread = None + +# Base class of exceptions that gspread.open(), open_by_url() and open_by_key() +# may throw. +GSpreadException = Exception if gspread is None else gspread.exceptions.GSpreadException # type: ignore + + +class SpreadsheetNotFoundError(RuntimeError): + pass + + +def _get_import_error() -> Exception: + return RuntimeError('"gspread" module not imported, got: {}'.format(_gspread_import_error)) + + +class GSpreadClient(abc.ABC): + """Wrapper around gspread.client.Client. + + This adds a layer of indirection for us to inject mocks for testing. + """ + + @abc.abstractmethod + def validate(self, sid: sheets_id.SheetsIdentifier) -> None: + """Validates that `name` is the name of a Google Sheets document. + + Raises an exception if false. + + Args: + sid: The identifier for the document. + """ + + @abc.abstractmethod + def get_all_records( + self, + sid: sheets_id.SheetsIdentifier, + worksheet_id: int, + ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]: + """Returns all records for a Google Sheets worksheet.""" + + @abc.abstractmethod + def write_records( + self, + sid: sheets_id.SheetsIdentifier, + rows: Sequence[Sequence[Any]], + ) -> None: + """Writes results to a new worksheet to the Google Sheets document.""" + + +class GSpreadClientImpl(GSpreadClient): + """Concrete implementation of GSpreadClient.""" + + def __init__(self, client: Any, env: ipython_env.IPythonEnv | None): + """Constructor. + + Args: + client: Instance of gspread.client.Client. + env: Optional instance of IPythonEnv. This is used to display messages + such as the URL of the output Worksheet. + """ + self._client = client + self._ipython_env = env + + def _open(self, sid: sheets_id.SheetsIdentifier): + """Opens a Sheets document from `sid`. + + Args: + sid: The identifier for the Sheets document. + + Raises: + SpreadsheetNotFoundError: If the Sheets document cannot be found or + cannot be opened. + + Returns: + A gspread.Worksheet instance representing the worksheet referred to by + `sid`. + """ + try: + if sid.name(): + return self._client.open(sid.name()) + if sid.key(): + return self._client.open_by_key(str(sid.key())) + if sid.url(): + return self._client.open_by_url(str(sid.url())) + except GSpreadException as exc: + raise SpreadsheetNotFoundError("Unable to find Sheets with {}".format(sid)) from exc + raise SpreadsheetNotFoundError("Invalid sheets_id.SheetsIdentifier") + + def validate(self, sid: sheets_id.SheetsIdentifier) -> None: + self._open(sid) + + def get_all_records( + self, + sid: sheets_id.SheetsIdentifier, + worksheet_id: int, + ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]: + sheet = self._open(sid) + worksheet = sheet.get_worksheet(worksheet_id) + + if self._ipython_env is not None: + env = self._ipython_env + + def _display_fn(): + env.display_html( + "Reading inputs from worksheet {}".format( + html_utils.get_anchor_tag( + url=sheets_id.SheetsURL(worksheet.url), + text="{} in {}".format(worksheet.title, sheet.title), + ) + ) + ) + + else: + + def _display_fn(): + print("Reading inputs from worksheet {} in {}".format(worksheet.title, sheet.title)) + + return worksheet.get_all_records(), _display_fn + + def write_records( + self, + sid: sheets_id.SheetsIdentifier, + rows: Sequence[Sequence[Any]], + ) -> None: + sheet = self._open(sid) + + # Create a new Worksheet. + # `title` has to be carefully constructed: some characters like colon ":" + # will not work with gspread in Worksheet.append_rows(). + current_datetime = datetime.datetime.now() + title = f"Results {current_datetime:%Y_%m_%d} ({current_datetime:%s})" + + # append_rows() will resize the worksheet as needed, so `rows` and `cols` + # can be set to 1 to create a worksheet with only a single cell. + worksheet = sheet.add_worksheet(title=title, rows=1, cols=1) + worksheet.append_rows(values=rows) + + if self._ipython_env is not None: + self._ipython_env.display_html( + "Results written to new worksheet {}".format( + html_utils.get_anchor_tag( + url=sheets_id.SheetsURL(worksheet.url), + text="{} in {}".format(worksheet.title, sheet.title), + ) + ) + ) + else: + print("Results written to new worksheet {} in {}".format(worksheet.title, sheet.title)) + + +class NullGSpreadClient(GSpreadClient): + """Null-object implementation of GSpreadClient. + + This class raises an error if any of its methods are called. It is used when + the gspread library is not available. + """ + + def validate(self, sid: sheets_id.SheetsIdentifier) -> None: + raise _get_import_error() + + def get_all_records( + self, + sid: sheets_id.SheetsIdentifier, + worksheet_id: int, + ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]: + raise _get_import_error() + + def write_records( + self, + sid: sheets_id.SheetsIdentifier, + rows: Sequence[Sequence[Any]], + ) -> None: + raise _get_import_error() + + +# Global instance of gspread client. +_gspread_client: GSpreadClient | None = None + + +def authorize(creds: credentials.Credentials, env: ipython_env.IPythonEnv | None) -> None: + """Sets up credential for gspreads.""" + global _gspread_client + if gspread is not None: + client = gspread.authorize(creds) # type: ignore + _gspread_client = GSpreadClientImpl(client=client, env=env) + else: + _gspread_client = NullGSpreadClient() + + +def get_client() -> GSpreadClient: + if not _gspread_client: + raise RuntimeError("Must call authorize() first") + return _gspread_client + + +def testonly_set_client(client: GSpreadClient) -> None: + """Overrides the global client for testing.""" + global _gspread_client + _gspread_client = client diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/html_utils.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/html_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..df7ab6b89855d31a363b27dc26de0e03c77be50e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/html_utils.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for generating HTML.""" +from __future__ import annotations + +from xml.etree import ElementTree +from google.generativeai.notebook import sheets_id + + +def get_anchor_tag(url: sheets_id.SheetsURL, text: str) -> str: + """Returns a HTML string representing an anchor tag. + + This class uses the xml.etree library to handle HTML escaping. + + Args: + url: The Sheets URL to link to. + text: The text body of the link. + + Returns: + A string representing a HTML fragment. + """ + tag = ElementTree.Element( + "a", + attrib={ + # Open in a new window/tab + "target": "_blank", + # See: + # https://developer.chrome.com/en/docs/lighthouse/best-practices/external-anchors-use-rel-noopener/ + "rel": "noopener", + "href": str(url), + }, + ) + tag.text = text if text else "link" + return ElementTree.tostring(tag, encoding="unicode", method="html") diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/input_utils.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/input_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..15575b5999c32689f3143c52fd0fbad30c33c805 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/input_utils.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for handling input variables.""" +from __future__ import annotations + +from typing import Callable, Mapping + +from google.generativeai.notebook import parsed_args_lib +from google.generativeai.notebook import py_utils +from google.generativeai.notebook.lib import llmfn_input_utils +from google.generativeai.notebook.lib import llmfn_inputs_source + + +class _NormalizedInputsSource(llmfn_inputs_source.LLMFnInputsSource): + """Wrapper around NormalizedInputsList. + + By design LLMFunction does not take NormalizedInputsList as input because + NormalizedInputsList is an internal representation so we want to minimize + exposure to the caller. + + When we have inputs already in normalized format (e.g. from + join_prompt_inputs()) we can wrap it as an LLMFnInputsSource to pass as an + input to LLMFunction. + """ + + def __init__(self, normalized_inputs: llmfn_inputs_source.NormalizedInputsList): + super().__init__() + self._normalized_inputs = normalized_inputs + + def _to_normalized_inputs_impl( + self, + ) -> tuple[llmfn_inputs_source.NormalizedInputsList, Callable[[], None]]: + return self._normalized_inputs, lambda: None + + +def get_inputs_source_from_py_var( + var_name: str, +) -> llmfn_inputs_source.LLMFnInputsSource: + data = py_utils.get_py_var(var_name) + if isinstance(data, llmfn_inputs_source.LLMFnInputsSource): + # No conversion needed. + return data + normalized_inputs = llmfn_input_utils.to_normalized_inputs(data) + return _NormalizedInputsSource(normalized_inputs) + + +def join_inputs_sources( + parsed_args: parsed_args_lib.ParsedArgs, + suppress_status_msgs: bool = False, +) -> llmfn_inputs_source.LLMFnInputsSource: + """Get a single combined input source from `parsed_args.""" + combined_inputs: list[Mapping[str, str]] = [] + for source in parsed_args.inputs: + combined_inputs.extend( + source.to_normalized_inputs(suppress_status_msgs=suppress_status_msgs) + ) + for source in parsed_args.sheets_input_names: + combined_inputs.extend( + source.to_normalized_inputs(suppress_status_msgs=suppress_status_msgs) + ) + return _NormalizedInputsSource(combined_inputs) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/ipython_env_impl.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/ipython_env_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..c46222438464016b65a47857f09341eb486763fb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/ipython_env_impl.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""IPythonEnvImpl.""" +from __future__ import annotations + +from typing import Any +from google.generativeai.notebook import ipython_env +from IPython.core import display as ipython_display + + +class IPythonEnvImpl(ipython_env.IPythonEnv): + """Concrete implementation of IPythonEnv.""" + + def display(self, x: Any) -> None: + ipython_display.display(x) + + def display_html(self, x: str) -> None: + ipython_display.display(ipython_display.HTML(x)) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/magics.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/magics.py new file mode 100644 index 0000000000000000000000000000000000000000..8b6af51b931f7e4afd60f76ec4f046831975959e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/magics.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Colab Magics class. + +Installs %%llm magics. +""" +from __future__ import annotations + +import abc + +from google.auth import credentials +from google.generativeai import client as genai +from google.generativeai.notebook import gspread_client +from google.generativeai.notebook import ipython_env +from google.generativeai.notebook import ipython_env_impl +from google.generativeai.notebook import magics_engine +from google.generativeai.notebook import post_process_utils +from google.generativeai.notebook import sheets_utils + +import IPython +from IPython.core import magic + + +# Set the UA to distinguish the magic from the client. Do this at import-time +# so that a user can still call `genai.configure()`, and both their settings +# and this are honored. +genai.USER_AGENT = "genai-py-magic" + +SheetsInputs = sheets_utils.SheetsInputs +SheetsOutputs = sheets_utils.SheetsOutputs + +# Decorator functions for post-processing. +post_process_add_fn = post_process_utils.post_process_add_fn +post_process_replace_fn = post_process_utils.post_process_replace_fn + +# Globals. +_ipython_env: ipython_env.IPythonEnv | None = None + + +def _get_ipython_env() -> ipython_env.IPythonEnv: + """Lazily constructs and returns a global IPythonEnv instance.""" + global _ipython_env + if _ipython_env is None: + _ipython_env = ipython_env_impl.IPythonEnvImpl() + return _ipython_env + + +def authorize(creds: credentials.Credentials) -> None: + """Sets up credentials. + + This is used for interacting Google APIs, such as Google Sheets. + + Args: + creds: The credentials that will be used (e.g. to read from Google Sheets.) + """ + gspread_client.authorize(creds=creds, env=_get_ipython_env()) + + +class AbstractMagics(abc.ABC): + """Defines interface to Magics class.""" + + @abc.abstractmethod + def llm(self, cell_line: str | None, cell_body: str | None): + """Perform various LLM-related operations. + + Args: + cell_line: String to pass to the MagicsEngine. + cell_body: Contents of the cell body. + """ + raise NotImplementedError() + + +class MagicsImpl(AbstractMagics): + """Actual class implementing the magics functionality. + + We use a separate class to ensure a single, global instance + of the magics class. + """ + + def __init__(self): + self._engine = magics_engine.MagicsEngine(env=_get_ipython_env()) + + def llm(self, cell_line: str | None, cell_body: str | None): + """Perform various LLM-related operations. + + Args: + cell_line: String to pass to the MagicsEngine. + cell_body: Contents of the cell body. + + Returns: + Results from running MagicsEngine. + """ + cell_line = cell_line or "" + cell_body = cell_body or "" + return self._engine.execute_cell(cell_line, cell_body) + + +@magic.magics_class +class Magics(magic.Magics): + """Class to register the magic with Colab. + + Objects of this class delegate all calls to a single, + global instance. + """ + + # Global instance + _instance = None + + @classmethod + def get_instance(cls) -> AbstractMagics: + """Retrieve global instance of the Magics object.""" + if cls._instance is None: + cls._instance = MagicsImpl() + return cls._instance + + @magic.line_cell_magic + def llm(self, cell_line: str | None, cell_body: str | None): + """Perform various LLM-related operations. + + Args: + cell_line: String to pass to the MagicsEngine. + cell_body: Contents of the cell body. + + Returns: + Results from running MagicsEngine. + """ + return Magics.get_instance().llm(cell_line=cell_line, cell_body=cell_body) + + @magic.line_cell_magic + def palm(self, cell_line: str | None, cell_body: str | None): + return self.llm(cell_line, cell_body) + + @magic.line_cell_magic + def gemini(self, cell_line: str | None, cell_body: str | None): + return self.llm(cell_line, cell_body) + + +IPython.get_ipython().register_magics(Magics) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/magics_engine.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/magics_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..ea6eeb46f69fcd53fd9d61375a21df62d59a6aa7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/magics_engine.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MagicsEngine class.""" +from __future__ import annotations + +from typing import AbstractSet, Sequence + +from google.generativeai.notebook import argument_parser +from google.generativeai.notebook import cmd_line_parser +from google.generativeai.notebook import command +from google.generativeai.notebook import compare_cmd +from google.generativeai.notebook import compile_cmd +from google.generativeai.notebook import eval_cmd +from google.generativeai.notebook import ipython_env +from google.generativeai.notebook import model_registry +from google.generativeai.notebook import parsed_args_lib +from google.generativeai.notebook import post_process_utils +from google.generativeai.notebook import run_cmd +from google.generativeai.notebook.lib import prompt_utils + + +class MagicsEngine: + """Implementation of functionality used by Magics. + + This class provides the implementation for Magics, decoupled from the + details of integrating with Colab Magics such as registration. + """ + + def __init__( + self, + registry: model_registry.ModelRegistry | None = None, + env: ipython_env.IPythonEnv | None = None, + ): + self._ipython_env = env + models = registry or model_registry.ModelRegistry() + self._cmd_handlers: dict[parsed_args_lib.CommandName, command.Command] = { + parsed_args_lib.CommandName.RUN_CMD: run_cmd.RunCommand(models=models, env=env), + parsed_args_lib.CommandName.COMPILE_CMD: compile_cmd.CompileCommand( + models=models, env=env + ), + parsed_args_lib.CommandName.COMPARE_CMD: compare_cmd.CompareCommand(env=env), + parsed_args_lib.CommandName.EVAL_CMD: eval_cmd.EvalCommand(models=models, env=env), + } + + def parse_line( + self, + line: str, + placeholders: AbstractSet[str], + ) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]: + return cmd_line_parser.CmdLineParser().parse_line(line, placeholders) + + def _get_handler(self, line: str, placeholders: AbstractSet[str]) -> tuple[ + command.Command, + parsed_args_lib.ParsedArgs, + Sequence[post_process_utils.ParsedPostProcessExpr], + ]: + """Given the command line, parse and return all components. + + Args: + line: The LLM Magics command line. + placeholders: Placeholders from prompts in the cell contents. + + Returns: + A three-tuple containing: + - The command (e.g. "run") + - Parsed arguments for the command, + - Parsed post-processing expressions + """ + parsed_args, post_processing_tokens = self.parse_line(line, placeholders) + cmd_name = parsed_args.cmd + handler = self._cmd_handlers[cmd_name] + post_processing_fns = handler.parse_post_processing_tokens(post_processing_tokens) + return handler, parsed_args, post_processing_fns + + def execute_cell(self, line: str, cell_content: str): + """Executes the supplied magic line and cell payload.""" + cell = _clean_cell(cell_content) + placeholders = prompt_utils.get_placeholders(cell) + + try: + handler, parsed_args, post_processing_fns = self._get_handler(line, placeholders) + return handler.execute(parsed_args, cell, post_processing_fns) + except argument_parser.ParserNormalExit as e: + if self._ipython_env is not None: + e.set_ipython_env(self._ipython_env) + # ParserNormalExit implements the _ipython_display_ method so it can + # be returned as the output of this cell for display. + return e + except argument_parser.ParserError as e: + e.display(self._ipython_env) + # Raise an exception to indicate that execution for this cell has + # failed. + # The exception is re-raised as SystemExit because Colab automatically + # suppresses traceback for SystemExit but not other exceptions. Because + # ParserErrors are usually due to user error (e.g. a missing required + # flag or an invalid flag value), we want to hide the traceback to + # avoid detracting the user from the error message, and we want to + # reserve exceptions-with-traceback for actual bugs and unexpected + # errors. + error_msg = "Got parser error: {}".format(e.msgs()[-1]) if e.msgs() else "" + raise SystemExit(error_msg) from e + + +def _clean_cell(cell_content: str) -> str: + # Colab includes a trailing newline in cell_content. Remove only the last + # line break from cell contents (i.e. not rstrip), so that multi-line and + # intentional line breaks are preserved, but single-line prompts don't have + # a trailing line break. + cell = cell_content + if cell.endswith("\n"): + cell = cell[:-1] + return cell diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/model_registry.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/model_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..f646461cf7ee160b78a634bcc48a60e4fa715809 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/model_registry.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Maintains set of LLM models that can be instantiated by name.""" +from __future__ import annotations + +import enum +from typing import Callable + +from google.generativeai.notebook import text_model +from google.generativeai.notebook.lib import model as model_lib + + +class ModelName(enum.Enum): + ECHO_MODEL = "echo" + TEXT_MODEL = "text" + + +class ModelRegistry: + """Registry that instantiates and caches models.""" + + DEFAULT_MODEL = ModelName.TEXT_MODEL + + def __init__(self): + self._model_cache: dict[ModelName, model_lib.AbstractModel] = {} + self._model_constructors: dict[ModelName, Callable[[], model_lib.AbstractModel]] = { + ModelName.ECHO_MODEL: model_lib.EchoModel, + ModelName.TEXT_MODEL: text_model.TextModel, + } + + def get_model(self, model_name: ModelName) -> model_lib.AbstractModel: + """Given `model_name`, return the corresponding Model instance. + + Model instances are cached and reused for the same `model_name`. + + Args: + model_name: The name of the model. + + Returns: + The corresponding model instance for `model_name`. + """ + if model_name not in self._model_cache: + self._model_cache[model_name] = self._model_constructors[model_name]() + return self._model_cache[model_name] diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/parsed_args_lib.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/parsed_args_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..dd42b1dda8d3b48db0859b90a9fab3eeef1a1cd7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/parsed_args_lib.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Results from parsing the commandline. + +This module separates the results from commandline parsing from the parser +itself so that classes that operate on the results (e.g. the subclasses of +Command) do not have to depend on the commandline parser as well. +""" +from __future__ import annotations + +import dataclasses +import enum +from typing import Any, Callable, Sequence + +from google.generativeai.notebook import model_registry +from google.generativeai.notebook.lib import llm_function +from google.generativeai.notebook.lib import llmfn_inputs_source +from google.generativeai.notebook.lib import llmfn_outputs +from google.generativeai.notebook.lib import model as model_lib + +# Post processing tokens are represented as a sequence of sequence of tokens, +# because the pipe operator could be used more than once. +PostProcessingTokens = Sequence[Sequence[str]] + +# The type of function taken by the "compare_fn" flag. +# It takes the text_results of the left- and right-hand side functions as +# inputs and returns a comparison result. +TextResultCompareFn = Callable[[str, str], Any] + + +class CommandName(enum.Enum): + RUN_CMD = "run" + COMPILE_CMD = "compile" + COMPARE_CMD = "compare" + EVAL_CMD = "eval" + + +@dataclasses.dataclass(frozen=True) +class ParsedArgs: + """The results of parsing the command line.""" + + cmd: CommandName + + # For run, compile and eval commands. + model_args: model_lib.ModelArguments + model_type: model_registry.ModelName | None = None + unique: bool = False + + # For run, compare and eval commands. + inputs: Sequence[llmfn_inputs_source.LLMFnInputsSource] = dataclasses.field( + default_factory=list + ) + sheets_input_names: Sequence[llmfn_inputs_source.LLMFnInputsSource] = dataclasses.field( + default_factory=list + ) + + outputs: Sequence[llmfn_outputs.LLMFnOutputsSink] = dataclasses.field(default_factory=list) + sheets_output_names: Sequence[llmfn_outputs.LLMFnOutputsSink] = dataclasses.field( + default_factory=list + ) + + # For compile command. + compile_save_name: str | None = None + + # For compare command. + lhs_name_and_fn: tuple[str, llm_function.LLMFunction] | None = None + rhs_name_and_fn: tuple[str, llm_function.LLMFunction] | None = None + + # For compare and eval commands. + compare_fn: Sequence[tuple[str, TextResultCompareFn]] = dataclasses.field(default_factory=list) + + # For eval command. + ground_truth: Sequence[str] = dataclasses.field(default_factory=list) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/post_process_utils_test_helper.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/post_process_utils_test_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..7cdf941d27430caa3d5490f59c77347e002e3d0a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/post_process_utils_test_helper.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper module for post_process_utils_test.""" +from __future__ import annotations + +from google.generativeai.notebook import post_process_utils + + +def add_length(x: str) -> int: + return len(x) + + +@post_process_utils.post_process_add_fn +def add_length_decorated(x: str) -> int: + return len(x) + + +@post_process_utils.post_process_replace_fn +def to_upper(x: str) -> str: + return x.upper() diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/run_cmd.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/run_cmd.py new file mode 100644 index 0000000000000000000000000000000000000000..d1f0a1265b563b79d91bb4d28eee69d6d9a83038 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/run_cmd.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The run command.""" +from __future__ import annotations + +from typing import Sequence + +from google.generativeai.notebook import command +from google.generativeai.notebook import command_utils +from google.generativeai.notebook import input_utils +from google.generativeai.notebook import ipython_env +from google.generativeai.notebook import model_registry +from google.generativeai.notebook import output_utils +from google.generativeai.notebook import parsed_args_lib +from google.generativeai.notebook import post_process_utils +import pandas + + +class RunCommand(command.Command): + """Implementation of the "run" command.""" + + def __init__( + self, + models: model_registry.ModelRegistry, + env: ipython_env.IPythonEnv | None = None, + ): + """Constructor. + + Args: + models: ModelRegistry instance. + env: The IPythonEnv environment. + """ + super().__init__() + self._models = models + self._ipython_env = env + + def execute( + self, + parsed_args: parsed_args_lib.ParsedArgs, + cell_content: str, + post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr], + ) -> pandas.DataFrame: + # We expect CmdLineParser to have already read the inputs once to validate + # that the placeholders in the prompt are present in the inputs, so we can + # suppress the status messages here. + inputs = input_utils.join_inputs_sources(parsed_args, suppress_status_msgs=True) + + llm_fn = command_utils.create_llm_function( + models=self._models, + env=self._ipython_env, + parsed_args=parsed_args, + cell_content=cell_content, + post_processing_fns=post_processing_fns, + ) + + results = llm_fn(inputs=inputs) + output_utils.write_to_outputs(results=results, parsed_args=parsed_args) + return results.as_pandas_dataframe() + + def parse_post_processing_tokens( + self, tokens: Sequence[Sequence[str]] + ) -> Sequence[post_process_utils.ParsedPostProcessExpr]: + return post_process_utils.resolve_post_processing_tokens(tokens) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_id.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_id.py new file mode 100644 index 0000000000000000000000000000000000000000..e49af4e169939f55e403c4aaa2ddc5de8c743b69 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_id.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for classes related to identifying a Sheets document.""" +from __future__ import annotations + +import re +from google.generativeai.notebook import sheets_sanitize_url + + +def _sanitize_key(key: str) -> str: + if not re.fullmatch("[a-zA-Z0-9_-]+", key): + raise ValueError('"{}" is not a valid Sheets key'.format(key)) + return key + + +class SheetsURL: + """Class that enforces safety by ensuring that URLs are sanitized.""" + + def __init__(self, url: str): + self._url: str = sheets_sanitize_url.sanitize_sheets_url(url) + + def __str__(self) -> str: + return self._url + + +class SheetsKey: + """Class that enforces safety by ensuring that keys are sanitized.""" + + def __init__(self, key: str): + self._key: str = _sanitize_key(key) + + def __str__(self) -> str: + return self._key + + +class SheetsIdentifier: + """Encapsulates a means to identify a Sheets document. + + The gspread library provides three ways to look up a Sheets document: by name, + by url and by key. An instance of this class represents exactly one of the + methods. + """ + + def __init__( + self, + name: str | None = None, + key: SheetsKey | None = None, + url: SheetsURL | None = None, + ): + """Constructor. + + Exactly one of the arguments should be provided. + + Args: + name: The name of the Sheets document. More-than-one Sheets documents can + have the same name, so this is the least precise method of identifying + the document. + key: The key of the Sheets document + url: The url to the Sheets document + + Raises: + ValueError: If the caller does not specify exactly one of name, url or + key. + """ + self._name = name + self._key = key + self._url = url + + # There should be exactly one. + num_inputs = int(bool(self._name)) + int(bool(self._key)) + int(bool(self._url)) + if num_inputs != 1: + raise ValueError("Must set exactly one of name, key or url") + + def name(self) -> str | None: + return self._name + + def key(self) -> SheetsKey | None: + return self._key + + def url(self) -> SheetsURL | None: + return self._url + + def __str__(self): + if self._name: + return "name={}".format(self._name) + elif self._key: + return "key={}".format(self._key) + else: + return "url={}".format(self._url) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_sanitize_url.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_sanitize_url.py new file mode 100644 index 0000000000000000000000000000000000000000..5188d0b4d5564ecbf0860bacea78674d26e51e8f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_sanitize_url.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for working with URLs.""" +from __future__ import annotations + +import re +from urllib import parse + + +def _validate_url_part(part: str) -> None: + if not re.fullmatch("[a-zA-Z0-9_-]*", part): + raise ValueError('"{}" is outside the restricted character set'.format(part)) + + +def _validate_url_query_or_fragment(part: str) -> None: + for key, values in parse.parse_qs(part).items(): + _validate_url_part(key) + for value in values: + _validate_url_part(value) + + +def sanitize_sheets_url(url: str) -> str: + """Sanitize a Sheets URL. + + Run some saftey checks to check whether `url` is a Sheets URL. This is not a + general-purpose URL sanitizer. Rather, it makes use of the fact that we know + the URL has to be for Sheets so we can make a few assumptions about (e.g. the + domain). + + Args: + url: The url to sanitize. + + Returns: + The sanitized url. + + Raises: + ValueError: If `url` does not match the expected restrictions for a Sheets + URL. + """ + parse_result = parse.urlparse(url) + if parse_result.scheme != "https": + raise ValueError( + 'Scheme for Sheets url must be "https", got "{}"'.format(parse_result.scheme) + ) + if parse_result.netloc not in ("docs.google.com", "sheets.googleapis.com"): + raise ValueError( + 'Domain for Sheets url must be "docs.google.com", got "{}"'.format(parse_result.netloc) + ) + + # Path component. + try: + for fragment in parse_result.path.split("/"): + _validate_url_part(fragment) + except ValueError as exc: + raise ValueError('Invalid path for Sheets url, got "{}"'.format(parse_result.path)) from exc + + # Params component. + if parse_result.params: + raise ValueError('Params component must be empty, got "{}"'.format(parse_result.params)) + + # Query component. + try: + _validate_url_query_or_fragment(parse_result.query) + except ValueError as exc: + raise ValueError( + 'Invalid query for Sheets url, got "{}"'.format(parse_result.query) + ) from exc + + # Fragment component. + try: + _validate_url_query_or_fragment(parse_result.fragment) + except ValueError as exc: + raise ValueError( + 'Invalid fragment for Sheets url, got "{}"'.format(parse_result.fragment) + ) from exc + + return url diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_utils.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec862eba1cd1a34ba21c26add9f7b760157977f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_utils.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SheetsInputs.""" +from __future__ import annotations + +from typing import Any, Callable, Mapping, Sequence +from urllib import parse +from google.generativeai.notebook import gspread_client +from google.generativeai.notebook import sheets_id +from google.generativeai.notebook.lib import llmfn_inputs_source +from google.generativeai.notebook.lib import llmfn_outputs + + +def _try_sheet_id_as_url(value: str) -> sheets_id.SheetsIdentifier | None: + """Try to open a Sheets document with `value` as a URL.""" + try: + parse_result = parse.urlparse(value) + except ValueError: + # If there's a URL parsing error, then it's not a URL. + return None + + if parse_result.scheme: + # If it looks like a URL, try to open the document as a URL but don't fall + # back to trying as key or name since it's very unlikely that a key or name + # looks like a URL. + sid = sheets_id.SheetsIdentifier(url=sheets_id.SheetsURL(value)) + gspread_client.get_client().validate(sid) + return sid + + return None + + +def _try_sheet_id_as_key(value: str) -> sheets_id.SheetsIdentifier | None: + """Try to open a Sheets document with `value` as a key.""" + try: + sid = sheets_id.SheetsIdentifier(key=sheets_id.SheetsKey(value)) + except ValueError: + # `value` is not a well-formed Sheets key. + return None + + try: + gspread_client.get_client().validate(sid) + except gspread_client.SpreadsheetNotFoundError: + return None + return sid + + +def _try_sheet_id_as_name(value: str) -> sheets_id.SheetsIdentifier | None: + """Try to open a Sheets document with `value` as a name.""" + sid = sheets_id.SheetsIdentifier(name=value) + try: + gspread_client.get_client().validate(sid) + except gspread_client.SpreadsheetNotFoundError: + return None + return sid + + +def get_sheets_id_from_str(value: str) -> sheets_id.SheetsIdentifier: + if sid := _try_sheet_id_as_url(value): + return sid + if sid := _try_sheet_id_as_key(value): + return sid + if sid := _try_sheet_id_as_name(value): + return sid + raise RuntimeError('No Sheets found with "{}" as URL, key or name'.format(value)) + + +class SheetsInputs(llmfn_inputs_source.LLMFnInputsSource): + """Inputs to an LLMFunction from Google Sheets.""" + + def __init__(self, sid: sheets_id.SheetsIdentifier, worksheet_id: int = 0): + super().__init__() + self._sid = sid + self._worksheet_id = worksheet_id + + def _to_normalized_inputs_impl( + self, + ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]: + return gspread_client.get_client().get_all_records( + sid=self._sid, worksheet_id=self._worksheet_id + ) + + +class SheetsOutputs(llmfn_outputs.LLMFnOutputsSink): + """Writes outputs from an LLMFunction to Google Sheets.""" + + def __init__(self, sid: sheets_id.SheetsIdentifier): + self._sid = sid + + def write_outputs(self, outputs: llmfn_outputs.LLMFnOutputsBase) -> None: + # Transpose `outputs` into a list of rows. + outputs_dict = outputs.as_dict() + outputs_rows: list[Sequence[Any]] = [list(outputs_dict.keys())] + outputs_rows.extend([list(x) for x in zip(*outputs_dict.values())]) + + gspread_client.get_client().write_records( + sid=self._sid, + rows=outputs_rows, + ) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/notebook/text_model.py b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/text_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7360bbfbd30730d94514525fcb9e4cc52747f1be --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/notebook/text_model.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model that uses the Text service.""" +from __future__ import annotations + +from google.api_core import retry +import google.generativeai as genai +from google.generativeai.types import generation_types +from google.generativeai.notebook.lib import model as model_lib + +_DEFAULT_MODEL = "models/gemini-1.5-flash" + + +class TextModel(model_lib.AbstractModel): + """Concrete model that uses the generate_content service.""" + + def _generate_text( + self, + prompt: str, + model: str | None = None, + temperature: float | None = None, + candidate_count: int | None = None, + ) -> generation_types.GenerateContentResponse: + gen_config = {} + if temperature is not None: + gen_config["temperature"] = temperature + if candidate_count is not None: + gen_config["candidate_count"] = candidate_count + + model_name = model or _DEFAULT_MODEL + gen_model = genai.GenerativeModel(model_name=model_name) + gc = genai.types.generation_types.GenerationConfig(**gen_config) + return gen_model.generate_content(prompt, generation_config=gc) + + def call_model( + self, + model_input: str, + model_args: model_lib.ModelArguments | None = None, + ) -> model_lib.ModelResults: + if model_args is None: + model_args = model_lib.ModelArguments() + + # Wrap the generation function here, rather than decorate, so that it + # applies to any overridden calls too. + retryable_fn = retry.Retry(retry.if_transient_error)(self._generate_text) + response = retryable_fn( + prompt=model_input, + model=model_args.model, + temperature=model_args.temperature, + candidate_count=model_args.candidate_count, + ) + + text_outputs = [] + for c in response.candidates: + text_outputs.append("".join(p.text for p in c.content.parts)) + + return model_lib.ModelResults( + model_input=model_input, + text_results=text_outputs, + ) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/operations.py b/.venv/lib/python3.11/site-packages/google/generativeai/operations.py new file mode 100644 index 0000000000000000000000000000000000000000..52fd8a1b80f0ee90cd70b73493cbf88237c87a2e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/operations.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import functools +from typing import Iterator + +from google.generativeai import protos + +from google.generativeai import client as client_lib +from google.generativeai.types import model_types +from google.api_core import operation as operation_lib + +import tqdm.auto as tqdm + + +def list_operations(*, client=None) -> Iterator[CreateTunedModelOperation]: + """Calls the API to list all operations""" + + if client is None: + client = client_lib.get_default_operations_client() + + # The client returns an iterator of Operation protos (`Iterator[google.longrunning.operations_pb2.Operation]`) + # not a gapic Operation object (`google.api_core.operation.Operation`) + operations = ( + CreateTunedModelOperation.from_proto(op, client) + for op in client.list_operations(name="", filter_="") + ) + + return operations + + +def get_operation(name: str, *, client=None) -> CreateTunedModelOperation: + """Calls the API to get a specific operation""" + if client is None: + client = client_lib.get_default_operations_client() + + op = client.get_operation(name=name) + return CreateTunedModelOperation.from_proto(op, client) + + +def delete_operation(name: str, *, client=None): + """Calls the API to delete a specific operation""" + + # Raises:google.api_core.exceptions.MethodNotImplemented: Not implemented. + if client is None: + client = client_lib.get_default_operations_client() + + return client.delete_operation(name=name) + + +class CreateTunedModelOperation(operation_lib.Operation): + @classmethod + def from_proto(cls, proto, client): + """ + result = getattr(proto, 'result', None) + if result is not None: + if result.value == b'': + del proto.result + """ + + return from_gapic( + cls=CreateTunedModelOperation, + operation=proto, + operations_client=client, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, + ) + + @classmethod + def from_core_operation( + cls, + operation: operation_lib.Operation, + ): + polling = getattr(operation, "_polling", None) + retry = getattr(operation, "_retry", None) + if polling is not None: + # google.api_core v 2.11 + kwargs = {"polling": polling} + elif retry is not None: + # google.api_core v 2.10 + kwargs = {"retry": retry} + else: + kwargs = {} + return cls( + operation=operation._operation, + refresh=operation._refresh, + cancel=operation._cancel, + result_type=operation._result_type, + metadata_type=operation._metadata_type, + **kwargs, + ) + + @property + def name(self) -> str: + return self._operation.name + + def update(self): + """Refresh the current statuses in metadata/result/error""" + self._refresh_and_update() + + def wait_bar(self, **kwargs) -> Iterator[protos.CreateTunedModelMetadata]: + """A tqdm wait bar, yields `Operation` statuses until complete. + + Args: + **kwargs: passed through to `tqdm.auto.tqdm(..., **kwargs)` + + Yields: + Operation statuses as `protos.CreateTunedModelMetadata` objects. + """ + bar = tqdm.tqdm(total=self.metadata.total_steps, initial=0, **kwargs) + + # done() includes a `_refresh_and_update` + while not self.done(): + metadata = self.metadata + bar.update(self.metadata.completed_steps - bar.n) + yield metadata + metadata = self.metadata + bar.update(self.metadata.completed_steps - bar.n) + return self.result() + + def set_result(self, result: protos.TunedModel): + result = model_types.decode_tuned_model(result) + super().set_result(result) + + +def from_gapic( + cls, + *, + operation, + operations_client, + result_type, + metadata_type, + grpc_metadata=None, + **kwargs, +): + """`google.api_core.operation.from_gapic`, patched to allow subclasses.""" + refresh = functools.partial( + operations_client.get_operation, operation.name, metadata=grpc_metadata + ) + cancel = functools.partial( + operations_client.cancel_operation, + operation.name, + metadata=grpc_metadata, + ) + return cls(operation, refresh, cancel, result_type, metadata_type, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/permission.py b/.venv/lib/python3.11/site-packages/google/generativeai/permission.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b7c15e1510d33dd037f129abe0458040e70b88 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/permission.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Callable + +import google.ai.generativelanguage as glm + +from google.generativeai.types import permission_types +from google.generativeai.types import retriever_types +from google.generativeai.types import model_types + + +_RESOURCE_TYPE: dict[str, str] = { + "corpus": "corpora", + "corpora": "corpora", + "tunedmodel": "tunedModels", + "tunedmodels": "tunedModels", +} + + +def _to_resource_type(x: str) -> str: + if isinstance(x, str): + x = x.lower() + resource_type = _RESOURCE_TYPE.get(x, None) + if not resource_type: + raise ValueError(f"Unsupported resource type. Got: `{x}` instead.") + + return resource_type + + +def _validate_resource_name(x: str, resource_type: str) -> None: + if resource_type == "corpora": + if not retriever_types.valid_name(x): + raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(x), name=x)) + + elif resource_type == "tunedModels": + if not model_types.valid_tuned_model_name(x): + raise ValueError(model_types.TUNED_MODEL_NAME_ERROR_MSG.format(length=len(x), name=x)) + + else: + raise ValueError(f"Unsupported resource type: {resource_type}") + + +def _validate_permission_id(x: str) -> None: + if not permission_types.valid_id(x): + raise ValueError(permission_types.INVALID_PERMISSION_ID_MSG.format(permission_id=x)) + + +def _get_valid_name_components(name: str) -> str: + # name is of the format: resource_type/resource_name/permissions/permission_id + name_path_components = name.split("/") + if len(name_path_components) != 4: + raise ValueError( + f"Invalid name format. Expected format: \ + `resource_type//permissions/`. Got: `{name}` instead." + ) + + resource_type, resource_name, permission_placeholder, permission_id = name_path_components + resource_type = _to_resource_type(resource_type) + + permission_id = "/".join([permission_placeholder, permission_id]) + + _validate_resource_name(resource_name, resource_type) + _validate_permission_id(permission_id) + + return "/".join([resource_type, resource_name, permission_id]) + + +def _construct_name( + name: str | None = None, + resource_name: str | None = None, + permission_id: str | int | None = None, + resource_type: str | None = None, +) -> str: + # resource_name is the name of the supported resource (corpus or tunedModel as of now) for which the permission is being created. + if not name: + # if name is not provided, then try to construct name via provided resource_name and permission_id. + if not (resource_name and permission_id): + raise ValueError( + f"Invalid arguments: Either `name` or both `resource_name` and `permission_id` must be provided. Received name: {name}, resource_name: {resource_name}, permission_id: {permission_id}." + ) + if resource_type: + resource_type = _to_resource_type(resource_type) + else: + # if resource_type is not provided, then try to infer it from resource_name. + resource_path_components = resource_name.split("/") + if len(resource_path_components) != 2: + raise ValueError( + f"Invalid `resource_name` format: Expected format is `resource_type/resource_name` (2 components). Received: `{resource_name}` with {len(resource_path_components)} components." + ) + resource_type = _to_resource_type(resource_path_components[0]) + + if f"{resource_type}/" in resource_name: + name = f"{resource_name}/" + else: + name = f"{resource_type}/{resource_name}/" + + if isinstance(permission_id, int) or "permissions/" not in permission_id: + name += f"permissions/{permission_id}" + else: + name += permission_id + + # if name is provided, override resource_name and permission_id + name = _get_valid_name_components(name) + return name + + +def get_permission( + name: str | None = None, + *, + client: glm.PermissionServiceClient | None = None, + resource_name: str | None = None, + permission_id: str | int | None = None, + resource_type: str | None = None, +) -> permission_types.Permission: + """Calls the API to retrieve detailed information about a specific permission based on resource type and permission identifiers + + Args: + name: The name of the permission. + resource_name: The name of the supported resource for which the permission details are needed. + permission_id: The name of the permission. + resource_type: The type of the resource (corpus or tunedModel as of now) for which the permission details are needed. + If not provided, it will be inferred from `resource_name`. + + Returns: + The permission as an instance of `permission_types.Permission`. + """ + name = _construct_name( + name=name, + resource_name=resource_name, + permission_id=permission_id, + resource_type=resource_type, + ) + return permission_types.Permission.get(name=name, client=client) + + +async def get_permission_async( + name: str | None = None, + *, + client: glm.PermissionServiceAsyncClient | None = None, + resource_name: str | None = None, + permission_id: str | int | None = None, + resource_type: str | None = None, +) -> permission_types.Permission: + """ + This is the async version of `permission.get_permission`. + """ + name = _construct_name( + name=name, + resource_name=resource_name, + permission_id=permission_id, + resource_type=resource_type, + ) + return await permission_types.Permission.get_async(name=name, client=client) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/protos.py b/.venv/lib/python3.11/site-packages/google/generativeai/protos.py new file mode 100644 index 0000000000000000000000000000000000000000..010396c75bfa2f0ea74b9470fb37994fadc7a0e8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/protos.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module provides low level access to the ProtoBuffer "Message" classes used by the API. + +**For typical usage of this SDK you do not need to use any of these classes.** + +ProtoBufers are Google API's serilization format. They are strongly typed and efficient. + +The `genai` SDK tries to be permissive about what objects it will accept from a user, but in the end +the SDK always converts input to an appropriate Proto Message object to send as the request. Each API request +has a `*Request` and `*Response` Message defined here. + +If you have any uncertainty about what the API may accept or return, these classes provide the +complete/unambiguous answer. They come from the `google-ai-generativelanguage` package which is +generated from a snapshot of the API definition. + +>>> from google.generativeai import protos +>>> import inspect +>>> print(inspect.getsource(protos.Part)) + +Proto classes can have "oneof" fields. Use `in` to check which `oneof` field is set. + +>>> p = protos.Part(text='hello') +>>> 'text' in p +True +>>> p.inline_data = {'mime_type':'image/png', 'data': b'PNG'} +>>> type(p.inline_data) is protos.Blob +True +>>> 'inline_data' in p +True +>>> 'text' in p +False + +Instances of all Message classes can be converted into JSON compatible dictionaries with the following construct +(Bytes are base64 encoded): + +>>> p_dict = type(p).to_dict(p) +>>> p_dict +{'inline_data': {'mime_type': 'image/png', 'data': 'UE5H'}} + +A compatible dict can be converted to an instance of a Message class by passing it as the first argument to the +constructor: + +>>> p = protos.Part(p_dict) +inline_data { + mime_type: "image/png" + data: "PNG" +} + +Note when converting that `to_dict` accepts additional arguments: + +- `use_integers_for_enums:bool = True`, Set it to `False` to replace enum int values with their string + names in the output +- ` including_default_value_fields:bool = True`, Set it to `False` to reduce the verbosity of the output. + +Additional arguments are described in the docstring: + +>>> help(proto.Part.to_dict) +""" + +from google.ai.generativelanguage_v1beta.types import * +from google.ai.generativelanguage_v1beta.types import __all__ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/py.typed b/.venv/lib/python3.11/site-packages/google/generativeai/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..d57989efb98c16eaedb423051e60a5ab1d4ceabd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/py.typed @@ -0,0 +1 @@ +# see: https://peps.python.org/pep-0561/ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/responder.py b/.venv/lib/python3.11/site-packages/google/generativeai/responder.py new file mode 100644 index 0000000000000000000000000000000000000000..dd388c6a6a80964bc907249996f292e099dc2f59 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/responder.py @@ -0,0 +1,512 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from collections.abc import Iterable, Mapping, Sequence +import inspect +import typing +from typing import Any, Callable, Union +from typing_extensions import TypedDict + +import pydantic + +from google.generativeai import protos + +Type = protos.Type + +TypeOptions = Union[int, str, Type] + +_TYPE_TYPE: dict[TypeOptions, Type] = { + Type.TYPE_UNSPECIFIED: Type.TYPE_UNSPECIFIED, + 0: Type.TYPE_UNSPECIFIED, + "type_unspecified": Type.TYPE_UNSPECIFIED, + "unspecified": Type.TYPE_UNSPECIFIED, + Type.STRING: Type.STRING, + 1: Type.STRING, + "type_string": Type.STRING, + "string": Type.STRING, + Type.NUMBER: Type.NUMBER, + 2: Type.NUMBER, + "type_number": Type.NUMBER, + "number": Type.NUMBER, + Type.INTEGER: Type.INTEGER, + 3: Type.INTEGER, + "type_integer": Type.INTEGER, + "integer": Type.INTEGER, + Type.BOOLEAN: Type.BOOLEAN, + 4: Type.INTEGER, + "type_boolean": Type.BOOLEAN, + "boolean": Type.BOOLEAN, + Type.ARRAY: Type.ARRAY, + 5: Type.ARRAY, + "type_array": Type.ARRAY, + "array": Type.ARRAY, + Type.OBJECT: Type.OBJECT, + 6: Type.OBJECT, + "type_object": Type.OBJECT, + "object": Type.OBJECT, +} + + +def to_type(x: TypeOptions) -> Type: + if isinstance(x, str): + x = x.lower() + return _TYPE_TYPE[x] + + +def _generate_schema( + f: Callable[..., Any], + *, + descriptions: Mapping[str, str] | None = None, + required: Sequence[str] | None = None, +) -> dict[str, Any]: + """Generates the OpenAPI Schema for a python function. + + Args: + f: The function to generate an OpenAPI Schema for. + descriptions: Optional. A `{name: description}` mapping for annotating input + arguments of the function with user-provided descriptions. It + defaults to an empty dictionary (i.e. there will not be any + description for any of the inputs). + required: Optional. For the user to specify the set of required arguments in + function calls to `f`. If unspecified, it will be automatically + inferred from `f`. + + Returns: + dict[str, Any]: The OpenAPI Schema for the function `f` in JSON format. + """ + if descriptions is None: + descriptions = {} + if required is None: + required = [] + defaults = dict(inspect.signature(f).parameters) + fields_dict = { + name: ( + # 1. We infer the argument type here: use Any rather than None so + # it will not try to auto-infer the type based on the default value. + (param.annotation if param.annotation != inspect.Parameter.empty else Any), + pydantic.Field( + # 2. We do not support default values for now. + # default=( + # param.default if param.default != inspect.Parameter.empty + # else None + # ), + # 3. We support user-provided descriptions. + description=descriptions.get(name, None), + ), + ) + for name, param in defaults.items() + # We do not support *args or **kwargs + if param.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_ONLY, + ) + } + parameters = pydantic.create_model(f.__name__, **fields_dict).model_json_schema() + # Postprocessing + # 4. Suppress unnecessary title generation: + # * https://github.com/pydantic/pydantic/issues/1051 + # * http://cl/586221780 + parameters.pop("title", None) + for name, function_arg in parameters.get("properties", {}).items(): + function_arg.pop("title", None) + annotation = defaults[name].annotation + # 5. Nullable fields: + # * https://github.com/pydantic/pydantic/issues/1270 + # * https://stackoverflow.com/a/58841311 + # * https://github.com/pydantic/pydantic/discussions/4872 + if typing.get_origin(annotation) is typing.Union and type(None) in typing.get_args( + annotation + ): + function_arg["nullable"] = True + # 6. Annotate required fields. + if required: + # We use the user-provided "required" fields if specified. + parameters["required"] = required + else: + # Otherwise we infer it from the function signature. + parameters["required"] = [ + k + for k in defaults + if ( + defaults[k].default == inspect.Parameter.empty + and defaults[k].kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_ONLY, + ) + ) + ] + schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters) + return schema + + +def _rename_schema_fields(schema: dict[str, Any]): + if schema is None: + return schema + + schema = schema.copy() + + type_ = schema.pop("type", None) + if type_ is not None: + schema["type_"] = type_ + type_ = schema.get("type_", None) + if type_ is not None: + schema["type_"] = to_type(type_) + + format_ = schema.pop("format", None) + if format_ is not None: + schema["format_"] = format_ + + items = schema.pop("items", None) + if items is not None: + schema["items"] = _rename_schema_fields(items) + + properties = schema.pop("properties", None) + if properties is not None: + schema["properties"] = {k: _rename_schema_fields(v) for k, v in properties.items()} + + return schema + + +class FunctionDeclaration: + def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None): + """A class wrapping a `protos.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" + self._proto = protos.FunctionDeclaration( + name=name, description=description, parameters=_rename_schema_fields(parameters) + ) + + @property + def name(self) -> str: + return self._proto.name + + @property + def description(self) -> str: + return self._proto.description + + @property + def parameters(self) -> protos.Schema: + return self._proto.parameters + + @classmethod + def from_proto(cls, proto) -> FunctionDeclaration: + self = cls(name="", description="", parameters={}) + self._proto = proto + return self + + def to_proto(self) -> protos.FunctionDeclaration: + return self._proto + + @staticmethod + def from_function(function: Callable[..., Any], descriptions: dict[str, str] | None = None): + """Builds a `CallableFunctionDeclaration` from a python function. + + The function should have type annotations. + + This method is able to generate the schema for arguments annotated with types: + + `AllowedTypes = float | int | str | list[AllowedTypes] | dict` + + This method does not yet build a schema for `TypedDict`, that would allow you to specify the dictionary + contents. But you can build these manually. + """ + + if descriptions is None: + descriptions = {} + + schema = _generate_schema(function, descriptions=descriptions) + + return CallableFunctionDeclaration(**schema, function=function) + + +StructType = dict[str, "ValueType"] +ValueType = Union[float, str, bool, StructType, list["ValueType"], None] + + +class CallableFunctionDeclaration(FunctionDeclaration): + """An extension of `FunctionDeclaration` that can be built from a Python function, and is callable. + + Note: The Python function must have type annotations. + """ + + def __init__( + self, + *, + name: str, + description: str, + parameters: dict[str, Any] | None = None, + function: Callable[..., Any], + ): + super().__init__(name=name, description=description, parameters=parameters) + self.function = function + + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse: + result = self.function(**fc.args) + if not isinstance(result, dict): + result = {"result": result} + return protos.FunctionResponse(name=fc.name, response=result) + + +FunctionDeclarationType = Union[ + FunctionDeclaration, + protos.FunctionDeclaration, + dict[str, Any], + Callable[..., Any], +] + + +def _make_function_declaration( + fun: FunctionDeclarationType, +) -> FunctionDeclaration | protos.FunctionDeclaration: + if isinstance(fun, (FunctionDeclaration, protos.FunctionDeclaration)): + return fun + elif isinstance(fun, dict): + if "function" in fun: + return CallableFunctionDeclaration(**fun) + else: + return FunctionDeclaration(**fun) + elif callable(fun): + return CallableFunctionDeclaration.from_function(fun) + else: + raise TypeError( + f"Invalid argument type: Expected an instance of `genai.FunctionDeclarationType`. Received type: {type(fun).__name__}.", + fun, + ) + + +def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.FunctionDeclaration: + if isinstance(fd, protos.FunctionDeclaration): + return fd + + return fd.to_proto() + + +class Tool: + """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" + + def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): + # The main path doesn't use this but is seems useful. + self._function_declarations = [_make_function_declaration(f) for f in function_declarations] + self._index = {} + for fd in self._function_declarations: + name = fd.name + if name in self._index: + raise ValueError("") + self._index[fd.name] = fd + + self._proto = protos.Tool( + function_declarations=[_encode_fd(fd) for fd in self._function_declarations] + ) + + @property + def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]: + return self._function_declarations + + def __getitem__( + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: + if not isinstance(name, str): + name = name.name + + return self._index[name] + + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse | None: + declaration = self[fc] + if not callable(declaration): + return None + + return declaration(fc) + + def to_proto(self): + return self._proto + + +class ToolDict(TypedDict): + function_declarations: list[FunctionDeclarationType] + + +ToolType = Union[ + Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType +] + + +def _make_tool(tool: ToolType) -> Tool: + if isinstance(tool, Tool): + return tool + elif isinstance(tool, protos.Tool): + return Tool(function_declarations=tool.function_declarations) + elif isinstance(tool, dict): + if "function_declarations" in tool: + return Tool(**tool) + else: + fd = tool + return Tool(function_declarations=[protos.FunctionDeclaration(**fd)]) + elif isinstance(tool, Iterable): + return Tool(function_declarations=tool) + else: + try: + return Tool(function_declarations=[tool]) + except Exception as e: + raise TypeError( + f"Invalid argument type: Expected an instance of `genai.ToolType`. Received type: {type(tool).__name__}.", + tool, + ) from e + + +class FunctionLibrary: + """A container for a set of `Tool` objects, manages lookup and execution of their functions.""" + + def __init__(self, tools: Iterable[ToolType]): + tools = _make_tools(tools) + self._tools = list(tools) + self._index = {} + for tool in self._tools: + for declaration in tool.function_declarations: + name = declaration.name + if name in self._index: + raise ValueError( + f"Invalid operation: A `FunctionDeclaration` named '{name}' is already defined. Each `FunctionDeclaration` must have a unique name." + ) + self._index[declaration.name] = declaration + + def __getitem__( + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: + if not isinstance(name, str): + name = name.name + + return self._index[name] + + def __call__(self, fc: protos.FunctionCall) -> protos.Part | None: + declaration = self[fc] + if not callable(declaration): + return None + + response = declaration(fc) + return protos.Part(function_response=response) + + def to_proto(self): + return [tool.to_proto() for tool in self._tools] + + +ToolsType = Union[Iterable[ToolType], ToolType] + + +def _make_tools(tools: ToolsType) -> list[Tool]: + if isinstance(tools, Iterable) and not isinstance(tools, Mapping): + tools = [_make_tool(t) for t in tools] + if len(tools) > 1 and all(len(t.function_declarations) == 1 for t in tools): + # flatten into a single tool. + tools = [_make_tool([t.function_declarations[0] for t in tools])] + return tools + else: + tool = tools + return [_make_tool(tool)] + + +FunctionLibraryType = Union[FunctionLibrary, ToolsType] + + +def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | None: + if lib is None: + return lib + elif isinstance(lib, FunctionLibrary): + return lib + else: + return FunctionLibrary(tools=lib) + + +FunctionCallingMode = protos.FunctionCallingConfig.Mode + +# fmt: off +_FUNCTION_CALLING_MODE = { + 1: FunctionCallingMode.AUTO, + FunctionCallingMode.AUTO: FunctionCallingMode.AUTO, + "mode_auto": FunctionCallingMode.AUTO, + "auto": FunctionCallingMode.AUTO, + + 2: FunctionCallingMode.ANY, + FunctionCallingMode.ANY: FunctionCallingMode.ANY, + "mode_any": FunctionCallingMode.ANY, + "any": FunctionCallingMode.ANY, + + 3: FunctionCallingMode.NONE, + FunctionCallingMode.NONE: FunctionCallingMode.NONE, + "mode_none": FunctionCallingMode.NONE, + "none": FunctionCallingMode.NONE, +} +# fmt: on + +FunctionCallingModeType = Union[FunctionCallingMode, str, int] + + +def to_function_calling_mode(x: FunctionCallingModeType) -> FunctionCallingMode: + if isinstance(x, str): + x = x.lower() + return _FUNCTION_CALLING_MODE[x] + + +class FunctionCallingConfigDict(TypedDict): + mode: FunctionCallingModeType + allowed_function_names: list[str] + + +FunctionCallingConfigType = Union[ + FunctionCallingModeType, FunctionCallingConfigDict, protos.FunctionCallingConfig +] + + +def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.FunctionCallingConfig: + if isinstance(obj, protos.FunctionCallingConfig): + return obj + elif isinstance(obj, (FunctionCallingMode, str, int)): + obj = {"mode": to_function_calling_mode(obj)} + elif isinstance(obj, dict): + obj = obj.copy() + mode = obj.pop("mode") + obj["mode"] = to_function_calling_mode(mode) + else: + raise TypeError( + "Invalid argument type: Could not convert input to `protos.FunctionCallingConfig`." + f" Received type: {type(obj).__name__}.", + obj, + ) + + return protos.FunctionCallingConfig(obj) + + +class ToolConfigDict: + function_calling_config: FunctionCallingConfigType + + +ToolConfigType = Union[ToolConfigDict, protos.ToolConfig] + + +def to_tool_config(obj: ToolConfigType) -> protos.ToolConfig: + if isinstance(obj, protos.ToolConfig): + return obj + elif isinstance(obj, dict): + fcc = obj.pop("function_calling_config") + fcc = to_function_calling_config(fcc) + obj["function_calling_config"] = fcc + return protos.ToolConfig(**obj) + else: + raise TypeError( + "Invalid argument type: Could not convert input to `protos.ToolConfig`. " + f"Received type: {type(obj).__name__}.", + ) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/retriever.py b/.venv/lib/python3.11/site-packages/google/generativeai/retriever.py new file mode 100644 index 0000000000000000000000000000000000000000..53c90140ab773d043fdb6e48a9e546749f011165 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/retriever.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + + +from typing import AsyncIterable, Iterable, Optional + +import google.ai.generativelanguage as glm +from google.generativeai import protos + +from google.generativeai.client import get_default_retriever_client +from google.generativeai.client import get_default_retriever_async_client +from google.generativeai.types import helper_types +from google.generativeai.types.model_types import idecode_time +from google.generativeai.types import retriever_types + + +def create_corpus( + name: str | None = None, + display_name: str | None = None, + client: glm.RetrieverServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> retriever_types.Corpus: + """Calls the API to create a new `Corpus` by specifying either a corpus resource name as an ID or a display name, and returns the created `Corpus`. + + Args: + name: The corpus resource name (ID). The name must be alphanumeric and fewer + than 40 characters. + display_name: The human readable display name. The display name must be fewer + than 128 characters. All characters, including alphanumeric, spaces, and + dashes are supported. + request_options: Options for the request. + + Return: + `retriever_types.Corpus` object with specified name or display name. + + Raises: + ValueError: When the name is not specified or formatted incorrectly. + """ + if request_options is None: + request_options = {} + + if client is None: + client = get_default_retriever_client() + + if name is None: + corpus = protos.Corpus(display_name=display_name) + elif retriever_types.valid_name(name): + corpus = protos.Corpus(name=f"corpora/{name}", display_name=display_name) + else: + raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name)) + + request = protos.CreateCorpusRequest(corpus=corpus) + response = client.create_corpus(request, **request_options) + response = type(response).to_dict(response) + idecode_time(response, "create_time") + idecode_time(response, "update_time") + response = retriever_types.Corpus(**response) + return response + + +async def create_corpus_async( + name: str | None = None, + display_name: str | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> retriever_types.Corpus: + """This is the async version of `retriever.create_corpus`.""" + if request_options is None: + request_options = {} + + if client is None: + client = get_default_retriever_async_client() + + if name is None: + corpus = protos.Corpus(display_name=display_name) + elif retriever_types.valid_name(name): + corpus = protos.Corpus(name=f"corpora/{name}", display_name=display_name) + else: + raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name)) + + request = protos.CreateCorpusRequest(corpus=corpus) + response = await client.create_corpus(request, **request_options) + response = type(response).to_dict(response) + idecode_time(response, "create_time") + idecode_time(response, "update_time") + response = retriever_types.Corpus(**response) + return response + + +def get_corpus( + name: str, + client: glm.RetrieverServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> retriever_types.Corpus: # fmt: skip + """Calls the API to fetch a `Corpus` by name and returns the `Corpus`. + + Args: + name: The `Corpus` name. + request_options: Options for the request. + + Return: + a `retriever_types.Corpus` of interest. + """ + if request_options is None: + request_options = {} + + if client is None: + client = get_default_retriever_client() + + if "/" not in name: + name = "corpora/" + name + + request = protos.GetCorpusRequest(name=name) + response = client.get_corpus(request, **request_options) + response = type(response).to_dict(response) + idecode_time(response, "create_time") + idecode_time(response, "update_time") + response = retriever_types.Corpus(**response) + return response + + +async def get_corpus_async( + name: str, + client: glm.RetrieverServiceAsyncClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> retriever_types.Corpus: # fmt: skip + """This is the async version of `retriever.get_corpus`.""" + + if request_options is None: + request_options = {} + + if client is None: + client = get_default_retriever_async_client() + + if "/" not in name: + name = "corpora/" + name + + request = protos.GetCorpusRequest(name=name) + response = await client.get_corpus(request, **request_options) + response = type(response).to_dict(response) + idecode_time(response, "create_time") + idecode_time(response, "update_time") + response = retriever_types.Corpus(**response) + return response + + +def delete_corpus( + name: str, + force: bool = False, + client: glm.RetrieverServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +): # fmt: skip + """Calls the API to remove a `Corpus` from the service, optionally deleting associated `Document`s and objects if the `force` parameter is set to true. + + Args: + name: The `Corpus` name. + force: If set to true, any `Document`s and objects related to this `Corpus` will also be deleted. + request_options: Options for the request. + + """ + if request_options is None: + request_options = {} + + if client is None: + client = get_default_retriever_client() + + if "/" not in name: + name = "corpora/" + name + + request = protos.DeleteCorpusRequest(name=name, force=force) + client.delete_corpus(request, **request_options) + + +async def delete_corpus_async( + name: str, + force: bool = False, + client: glm.RetrieverServiceAsyncClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +): # fmt: skip + """This is the async version of `retriever.delete_corpus`.""" + if request_options is None: + request_options = {} + + if client is None: + client = get_default_retriever_async_client() + + if "/" not in name: + name = "corpora/" + name + + request = protos.DeleteCorpusRequest(name=name, force=force) + await client.delete_corpus(request, **request_options) + + +def list_corpora( + *, + page_size: Optional[int] = None, + client: glm.RetrieverServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> Iterable[retriever_types.Corpus]: + """Calls the API to list all `Corpora` in the service and returns a list of paginated `Corpora`. + + Args: + page_size: Maximum number of `Corpora` to request. + page_token: A page token, received from a previous ListCorpora call. + request_options: Options for the request. + + Return: + Paginated list of `Corpora`. + """ + if request_options is None: + request_options = {} + + if client is None: + client = get_default_retriever_client() + + request = protos.ListCorporaRequest(page_size=page_size) + for corpus in client.list_corpora(request, **request_options): + corpus = type(corpus).to_dict(corpus) + idecode_time(corpus, "create_time") + idecode_time(corpus, "update_time") + yield retriever_types.Corpus(**corpus) + + +async def list_corpora_async( + *, + page_size: Optional[int] = None, + client: glm.RetrieverServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, +) -> AsyncIterable[retriever_types.Corpus]: + """This is the async version of `retriever.list_corpora`.""" + if request_options is None: + request_options = {} + + if client is None: + client = get_default_retriever_async_client() + + request = protos.ListCorporaRequest(page_size=page_size) + async for corpus in await client.list_corpora(request, **request_options): + corpus = type(corpus).to_dict(corpus) + idecode_time(corpus, "create_time") + idecode_time(corpus, "update_time") + yield retriever_types.Corpus(**corpus) diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/string_utils.py b/.venv/lib/python3.11/site-packages/google/generativeai/string_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..049bb885b82a0436d2cca777a0dcdda4c1900f9c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/string_utils.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import dataclasses +import pprint +import re +import reprlib +import textwrap + + +def set_doc(doc): + """A decorator to set the docstring of a function.""" + + def inner(f): + f.__doc__ = doc + return f + + return inner + + +def strip_oneof(docstring): + lines = docstring.splitlines() + lines = [line for line in lines if ".. _oneof:" not in line] + lines = [line for line in lines if "This field is a member of `oneof`_" not in line] + return "\n".join(lines) + + +def prettyprint(cls): + cls.__str__ = _prettyprint + cls.__repr__ = _prettyprint + return cls + + +repr = reprlib.Repr() + + +@reprlib.recursive_repr() +def _prettyprint(self): + """A dataclass prettyprint function you can use in __str__or __repr__. + + Note: You can't set `__str__ = pprint.pformat` because it causes a recursion error. + + Mostly identical to pprint but: + + * This will contract long lists and dicts (> 10lines) to [...] and {...}. + * This will contract long object reprs to ClassName(...). + """ + fields = [] + for f in dataclasses.fields(self): + s = pprint.pformat(getattr(self, f.name)) + class_re = r"^(\w+)\(.*\)$" + if s.count("\n") >= 10: + if s.startswith("["): + s = "[...]" + elif s.startswith("{"): + s = "{...}" + elif re.match(class_re, s, flags=re.DOTALL): + s = re.sub(class_re, r"\1(...)", s, flags=re.DOTALL) + else: + s = "..." + else: + width = len(f.name) + 1 + s = textwrap.indent(s, " " * width).lstrip(" ") + fields.append(f"{f.name}={s}") + attrs = ",\n".join(fields) + + name = self.__class__.__name__ + width = len(name) + 1 + + attrs = textwrap.indent(attrs, " " * width).lstrip(" ") + return f"{name}({attrs})" diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/text.py b/.venv/lib/python3.11/site-packages/google/generativeai/text.py new file mode 100644 index 0000000000000000000000000000000000000000..3a147f945e3a4b8f2b3af669b68aafb462abce5b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/text.py @@ -0,0 +1,342 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import dataclasses +from collections.abc import Iterable, Sequence +import itertools +from typing import Any, Iterable, overload, TypeVar + +import google.ai.generativelanguage as glm + +from google.generativeai.client import get_default_text_client +from google.generativeai import string_utils +from google.generativeai.types import text_types +from google.generativeai.types import model_types +from google.generativeai import models +from google.generativeai.types import safety_types + +DEFAULT_TEXT_MODEL = "models/text-bison-001" +EMBEDDING_MAX_BATCH_SIZE = 100 + +try: + # python 3.12+ + _batched = itertools.batched # type: ignore +except AttributeError: + T = TypeVar("T") + + def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]: + if n < 1: + raise ValueError(f"Batch size `n` must be >1, got: {n}") + batch = [] + for item in iterable: + batch.append(item) + if len(batch) == n: + yield batch + batch = [] + + if batch: + yield batch + + +def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt: + """ + Creates a `glm.TextPrompt` object based on the provided prompt input. + + Args: + prompt: The prompt input, either a string or a dictionary. + + Returns: + glm.TextPrompt: A TextPrompt object containing the prompt text. + + Raises: + TypeError: If the provided prompt is neither a string nor a dictionary. + """ + if isinstance(prompt, str): + return glm.TextPrompt(text=prompt) + elif isinstance(prompt, dict): + return glm.TextPrompt(prompt) + else: + TypeError("Expected string or dictionary for text prompt.") + + +def _make_generate_text_request( + *, + model: model_types.AnyModelNameOptions = DEFAULT_TEXT_MODEL, + prompt: str | None = None, + temperature: float | None = None, + candidate_count: int | None = None, + max_output_tokens: int | None = None, + top_p: int | None = None, + top_k: int | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + stop_sequences: str | Iterable[str] | None = None, +) -> glm.GenerateTextRequest: + """ + Creates a `glm.GenerateTextRequest` object based on the provided parameters. + + This function generates a `glm.GenerateTextRequest` object with the specified + parameters. It prepares the input parameters and creates a request that can be + used for generating text using the chosen model. + + Args: + model: The model to use for text generation. + prompt: The prompt for text generation. Defaults to None. + temperature: The temperature for randomness in generation. Defaults to None. + candidate_count: The number of candidates to consider. Defaults to None. + max_output_tokens: The maximum number of output tokens. Defaults to None. + top_p: The nucleus sampling probability threshold. Defaults to None. + top_k: The top-k sampling parameter. Defaults to None. + safety_settings: Safety settings for generated text. Defaults to None. + stop_sequences: Stop sequences to halt text generation. Can be a string + or iterable of strings. Defaults to None. + + Returns: + `glm.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters. + """ + model = model_types.make_model_name(model) + prompt = _make_text_prompt(prompt=prompt) + safety_settings = safety_types.normalize_safety_settings( + safety_settings, harm_category_set="old" + ) + if isinstance(stop_sequences, str): + stop_sequences = [stop_sequences] + if stop_sequences: + stop_sequences = list(stop_sequences) + + return glm.GenerateTextRequest( + model=model, + prompt=prompt, + temperature=temperature, + candidate_count=candidate_count, + max_output_tokens=max_output_tokens, + top_p=top_p, + top_k=top_k, + safety_settings=safety_settings, + stop_sequences=stop_sequences, + ) + + +def generate_text( + *, + model: model_types.AnyModelNameOptions = DEFAULT_TEXT_MODEL, + prompt: str, + temperature: float | None = None, + candidate_count: int | None = None, + max_output_tokens: int | None = None, + top_p: float | None = None, + top_k: float | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + stop_sequences: str | Iterable[str] | None = None, + client: glm.TextServiceClient | None = None, + request_options: dict[str, Any] | None = None, +) -> text_types.Completion: + """Calls the API and returns a `types.Completion` containing the response. + + Args: + model: Which model to call, as a string or a `types.Model`. + prompt: Free-form input text given to the model. Given a prompt, the model will + generate text that completes the input text. + temperature: Controls the randomness of the output. Must be positive. + Typical values are in the range: `[0.0,1.0]`. Higher values produce a + more random and varied response. A temperature of zero will be deterministic. + candidate_count: The **maximum** number of generated response messages to return. + This value must be between `[1, 8]`, inclusive. If unset, this + will default to `1`. + + Note: Only unique candidates are returned. Higher temperatures are more + likely to produce unique candidates. Setting `temperature=0.0` will always + return 1 candidate regardless of the `candidate_count`. + max_output_tokens: Maximum number of tokens to include in a candidate. Must be greater + than zero. If unset, will default to 64. + top_k: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling. + `top_k` sets the maximum number of tokens to sample from on each step. + top_p: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling. + `top_p` configures the nucleus sampling. It sets the maximum cumulative + probability of tokens to sample from. + For example, if the sorted probabilities are + `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample + as `[0.625, 0.25, 0.125, 0, 0, 0]`. + safety_settings: A list of unique `types.SafetySetting` instances for blocking unsafe content. + These will be enforced on the `prompt` and + `candidates`. There should not be more than one + setting for each `types.SafetyCategory` type. The API will block any prompts and + responses that fail to meet the thresholds set by these settings. This list + overrides the default settings for each `SafetyCategory` specified in the + safety_settings. If there is no `types.SafetySetting` for a given + `SafetyCategory` provided in the list, the API will use the default safety + setting for that category. + stop_sequences: A set of up to 5 character sequences that will stop output generation. + If specified, the API will stop at the first appearance of a stop + sequence. The stop sequence will not be included as part of the response. + client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. + request_options: Options for the request. + + Returns: + A `types.Completion` containing the model's text completion response. + """ + request = _make_generate_text_request( + model=model, + prompt=prompt, + temperature=temperature, + candidate_count=candidate_count, + max_output_tokens=max_output_tokens, + top_p=top_p, + top_k=top_k, + safety_settings=safety_settings, + stop_sequences=stop_sequences, + ) + + return _generate_response(client=client, request=request, request_options=request_options) + + +@string_utils.prettyprint +@dataclasses.dataclass(init=False) +class Completion(text_types.Completion): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + self.result = None + if self.candidates: + self.result = self.candidates[0]["output"] + + +def _generate_response( + request: glm.GenerateTextRequest, + client: glm.TextServiceClient = None, + request_options: dict[str, Any] | None = None, +) -> Completion: + """ + Generates a response using the provided `glm.GenerateTextRequest` and client. + + Args: + request: The text generation request. + client: The client to use for text generation. Defaults to None, in which + case the default text client is used. + request_options: Options for the request. + + Returns: + `Completion`: A `Completion` object with the generated text and response information. + """ + if request_options is None: + request_options = {} + + if client is None: + client = get_default_text_client() + + response = client.generate_text(request, **request_options) + response = type(response).to_dict(response) + + response["filters"] = safety_types.convert_filters_to_enums(response["filters"]) + response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums( + response["safety_feedback"] + ) + response["candidates"] = safety_types.convert_candidate_enums(response["candidates"]) + + return Completion(_client=client, **response) + + +def count_text_tokens( + model: model_types.AnyModelNameOptions, + prompt: str, + client: glm.TextServiceClient | None = None, + request_options: dict[str, Any] | None = None, +) -> text_types.TokenCount: + base_model = models.get_base_model_name(model) + + if request_options is None: + request_options = {} + + if client is None: + client = get_default_text_client() + + result = client.count_text_tokens( + glm.CountTextTokensRequest(model=base_model, prompt={"text": prompt}), + **request_options, + ) + + return type(result).to_dict(result) + + +@overload +def generate_embeddings( + model: model_types.BaseModelNameOptions, + text: str, + client: glm.TextServiceClient = None, + request_options: dict[str, Any] | None = None, +) -> text_types.EmbeddingDict: ... + + +@overload +def generate_embeddings( + model: model_types.BaseModelNameOptions, + text: Sequence[str], + client: glm.TextServiceClient = None, + request_options: dict[str, Any] | None = None, +) -> text_types.BatchEmbeddingDict: ... + + +def generate_embeddings( + model: model_types.BaseModelNameOptions, + text: str | Sequence[str], + client: glm.TextServiceClient = None, + request_options: dict[str, Any] | None = None, +) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: + """Calls the API to create an embedding for the text passed in. + + Args: + model: Which model to call, as a string or a `types.Model`. + + text: Free-form input text given to the model. Given a string, the model will + generate an embedding based on the input text. + + client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. + + request_options: Options for the request. + + Returns: + Dictionary containing the embedding (list of float values) for the input text. + """ + model = model_types.make_model_name(model) + + if request_options is None: + request_options = {} + + if client is None: + client = get_default_text_client() + + if isinstance(text, str): + embedding_request = glm.EmbedTextRequest(model=model, text=text) + embedding_response = client.embed_text( + embedding_request, + **request_options, + ) + embedding_dict = type(embedding_response).to_dict(embedding_response) + embedding_dict["embedding"] = embedding_dict["embedding"]["value"] + else: + result = {"embedding": []} + for batch in _batched(text, EMBEDDING_MAX_BATCH_SIZE): + # TODO(markdaoust): This could use an option for returning an iterator or wait-bar. + embedding_request = glm.BatchEmbedTextRequest(model=model, texts=batch) + embedding_response = client.batch_embed_text( + embedding_request, + **request_options, + ) + embedding_dict = type(embedding_response).to_dict(embedding_response) + result["embedding"].extend(e["value"] for e in embedding_dict["embeddings"]) + return result + + return embedding_dict diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/utils.py b/.venv/lib/python3.11/site-packages/google/generativeai/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cd2c4cbf7f00979367b6ee383b0e894cdf063829 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/utils.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + + +def flatten_update_paths(updates): + """Flattens a nested dictionary into a single level dictionary, with keys representing the original path.""" + + new_updates = {} + for key, value in updates.items(): + if isinstance(value, dict): + for sub_key, sub_value in flatten_update_paths(value).items(): + new_updates[f"{key}.{sub_key}"] = sub_value + else: + new_updates[key] = value + + return new_updates diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/version.py b/.venv/lib/python3.11/site-packages/google/generativeai/version.py new file mode 100644 index 0000000000000000000000000000000000000000..b5271a21db6391dc07a454b98b60c2f7f940b73d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/version.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +__version__ = "0.8.4" diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/vision_models/__init__.py b/.venv/lib/python3.11/site-packages/google/generativeai/vision_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f733dd6e11d79a79f2238c17db6777b2a83b759 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/vision_models/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Classes for working with vision models.""" + +from google.generativeai.types.image_types import check_watermark, Image, GeneratedImage + +from google.generativeai.vision_models._vision_models import ( + ImageGenerationModel, + ImageGenerationResponse, + VideoGenerationModel, +) + +__all__ = [ + "Image", + "GeneratedImage", + "ImageGenerationModel", + "ImageGenerationResponse", + "VideoGenerationModel", +] diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/vision_models/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/vision_models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05f93acc8c422106683622a6686975b53d5ba610 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/vision_models/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/vision_models/__pycache__/_vision_models.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/generativeai/vision_models/__pycache__/_vision_models.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..224cd6d904e854a084c68bb273a8adff402ad108 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/generativeai/vision_models/__pycache__/_vision_models.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/generativeai/vision_models/_vision_models.py b/.venv/lib/python3.11/site-packages/google/generativeai/vision_models/_vision_models.py new file mode 100644 index 0000000000000000000000000000000000000000..31dd06b35d310251dc6f1f05a575b98d9ffb54aa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/generativeai/vision_models/_vision_models.py @@ -0,0 +1,456 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=bad-continuation, line-too-long, protected-access +"""Classes for working with vision models.""" + +import base64 +import dataclasses + +import pathlib +import typing +from typing import Iterator, List, Literal, Optional, Union + +from google.generativeai import client +from google.generativeai.types import image_types + +from google.generativeai import files +from google.generativeai import protos +from google.generativeai import operations + +ImageAspectRatio = Literal["1:1", "9:16", "16:9", "4:3", "3:4"] +IMAGE_ASPECT_RATIOS = ImageAspectRatio.__args__ # type: ignore + +VideoAspectRatio = Literal["9:16", "16:9"] +VIDEO_ASPECT_RATIOS = VideoAspectRatio.__args__ # type: ignore + +OutputMimeType = Literal["image/png", "image/jpeg"] +OUTPUT_MIME_TYPES = OutputMimeType.__args__ # type: ignore + +SafetyFilterLevel = Literal["block_low_and_above", "block_medium_and_above", "block_only_high"] +SAFETY_FILTER_LEVELS = SafetyFilterLevel.__args__ # type: ignore + +PersonGeneration = Literal["dont_allow", "allow_adult"] +PERSON_GENERATIONS = PersonGeneration.__args__ # type: ignore + + +class ImageGenerationModel: + """Generates images from text prompt. + + Examples:: + + model = ImageGenerationModel.from_pretrained("imagegeneration@002") + response = model.generate_images( + prompt="Astronaut riding a horse", + # Optional: + number_of_images=1, + ) + response[0].show() + response[0].save("image1.png") + """ + + def __init__(self, model_id: str): + if not model_id.startswith("models"): + model_id = f"models/{model_id}" + self.model_name = model_id + self._client = None + + @classmethod + def from_pretrained(cls, model_name: str = "imagen-3.0-generate-001"): + """For vertex compatibility""" + return cls(model_name) + + def _generate_images( + self, + prompt: str, + *, + negative_prompt: Optional[str] = None, + number_of_images: int = 1, + width: Optional[int] = None, + height: Optional[int] = None, + aspect_ratio: Optional[ImageAspectRatio] = None, + guidance_scale: Optional[float] = None, + output_mime_type: Optional[OutputMimeType] = None, + compression_quality: Optional[float] = None, + language: Optional[str] = None, + safety_filter_level: Optional[SafetyFilterLevel] = None, + person_generation: Optional[PersonGeneration] = None, + ) -> "ImageGenerationResponse": + """Generates images from text prompt. + + Args: + prompt: Text prompt for the image. + negative_prompt: A description of what you want to omit in the generated + images. + number_of_images: Number of images to generate. Range: 1..8. + width: Width of the image. One of the sizes must be 256 or 1024. + height: Height of the image. One of the sizes must be 256 or 1024. + aspect_ratio: Aspect ratio for the image. Supported values are: + * 1:1 - Square image + * 9:16 - Portait image + * 16:9 - Landscape image + * 4:3 - Landscape, desktop ratio. + * 3:4 - Portrait, desktop ratio + guidance_scale: Controls the strength of the prompt. Suggested values + are - * 0-9 (low strength) * 10-20 (medium strength) * 21+ (high + strength) + output_mime_type: Which image format should the output be saved as. + Supported values: * image/png: Save as a PNG image * image/jpeg: Save + as a JPEG image + compression_quality: Level of compression if the output mime type is + selected to be image/jpeg. Float between 0 to 100 + language: Language of the text prompt for the image. Default: None. + Supported values are `"en"` for English, `"hi"` for Hindi, `"ja"` for + Japanese, `"ko"` for Korean, and `"auto"` for automatic language + detection. + safety_filter_level: Adds a filter level to Safety filtering. Supported + values are: + * "block_most" : Strongest filtering level, most strict blocking + * "block_some" : Block some problematic prompts and responses + * "block_few" : Block fewer problematic prompts and responses + person_generation: Allow generation of people by the model Supported + values are: + * "dont_allow" : Block generation of people + * "allow_adult" : Generate adults, but not children + + Returns: + An `ImageGenerationResponse` object. + """ + if self._client is None: + self._client = client.get_default_prediction_client() + # Note: Only a single prompt is supported by the service. + instance = {"prompt": prompt} + shared_generation_parameters = { + "prompt": prompt, + # b/295946075 The service stopped supporting image sizes. + # "width": width, + # "height": height, + "number_of_images_in_batch": number_of_images, + } + + parameters = {} + max_size = max(width or 0, height or 0) or None + if aspect_ratio is not None: + if aspect_ratio not in IMAGE_ASPECT_RATIOS: + raise ValueError(f"aspect_ratio not in {IMAGE_ASPECT_RATIOS}") + parameters["aspectRatio"] = aspect_ratio + elif max_size: + # Note: The size needs to be a string + parameters["sampleImageSize"] = str(max_size) + if height is not None and width is not None and height != width: + parameters["aspectRatio"] = f"{width}:{height}" + + parameters["sampleCount"] = number_of_images + if negative_prompt: + parameters["negativePrompt"] = negative_prompt + shared_generation_parameters["negative_prompt"] = negative_prompt + + if guidance_scale is not None: + parameters["guidanceScale"] = guidance_scale + shared_generation_parameters["guidance_scale"] = guidance_scale + + if language is not None: + parameters["language"] = language + shared_generation_parameters["language"] = language + + parameters["outputOptions"] = {} + if output_mime_type is not None: + if output_mime_type not in OUTPUT_MIME_TYPES: + raise ValueError(f"output_mime_type not in {OUTPUT_MIME_TYPES}") + parameters["outputOptions"]["mimeType"] = output_mime_type + shared_generation_parameters["mime_type"] = output_mime_type + + if compression_quality is not None: + parameters["outputOptions"]["compressionQuality"] = compression_quality + shared_generation_parameters["compression_quality"] = compression_quality + + if safety_filter_level is not None: + if safety_filter_level not in SAFETY_FILTER_LEVELS: + raise ValueError(f"safety_filter_level not in {SAFETY_FILTER_LEVELS}") + parameters["safetySetting"] = safety_filter_level + shared_generation_parameters["safety_filter_level"] = safety_filter_level + + if person_generation is not None: + parameters["personGeneration"] = person_generation + shared_generation_parameters["person_generation"] = person_generation + + response = self._client.predict( + model=self.model_name, instances=[instance], parameters=parameters + ) + + generated_images: List[image_types.GeneratedImage] = [] + for idx, prediction in enumerate(response.predictions): + generation_parameters = dict(shared_generation_parameters) + generation_parameters["index_of_image_in_batch"] = idx + encoded_bytes = prediction.get("bytesBase64Encoded") + generated_image = image_types.GeneratedImage( + image_bytes=base64.b64decode(encoded_bytes) if encoded_bytes else None, + generation_parameters=generation_parameters, + ) + generated_images.append(generated_image) + + return ImageGenerationResponse(images=generated_images) + + def generate_images( + self, + prompt: str, + *, + negative_prompt: Optional[str] = None, + number_of_images: int = 1, + aspect_ratio: Optional[ImageAspectRatio] = None, + guidance_scale: Optional[float] = None, + language: Optional[str] = None, + safety_filter_level: Optional[SafetyFilterLevel] = None, + person_generation: Optional[PersonGeneration] = None, + ) -> "ImageGenerationResponse": + """Generates images from text prompt. + + Args: + prompt: Text prompt for the image. + negative_prompt: A description of what you want to omit in the generated + images. + number_of_images: Number of images to generate. Range: 1..8. + aspect_ratio: Changes the aspect ratio of the generated image Supported + values are: + * "1:1" : 1:1 aspect ratio + * "9:16" : 9:16 aspect ratio + * "16:9" : 16:9 aspect ratio + * "4:3" : 4:3 aspect ratio + * "3:4" : 3:4 aspect_ratio + guidance_scale: Controls the strength of the prompt. Suggested values are: + * 0-9 (low strength) + * 10-20 (medium strength) + * 21+ (high strength) + language: Language of the text prompt for the image. Default: None. + Supported values are `"en"` for English, `"hi"` for Hindi, `"ja"` + for Japanese, `"ko"` for Korean, and `"auto"` for automatic language + detection. + safety_filter_level: Adds a filter level to Safety filtering. Supported + values are: + * "block_most" : Strongest filtering level, most strict + blocking + * "block_some" : Block some problematic prompts and responses + * "block_few" : Block fewer problematic prompts and responses + person_generation: Allow generation of people by the model Supported + values are: + * "dont_allow" : Block generation of people + * "allow_adult" : Generate adults, but not children + Returns: + An `ImageGenerationResponse` object. + """ + return self._generate_images( + prompt=prompt, + negative_prompt=negative_prompt, + number_of_images=number_of_images, + aspect_ratio=aspect_ratio, + guidance_scale=guidance_scale, + language=language, + safety_filter_level=safety_filter_level, + person_generation=person_generation, + ) + + +@dataclasses.dataclass +class ImageGenerationResponse: + """Image generation response. + + Attributes: + images: The list of generated images. + """ + + __module__ = "vertexai.preview.vision_models" + + images: List[image_types.GeneratedImage] + + def __iter__(self) -> typing.Iterator[image_types.GeneratedImage]: + """Iterates through the generated images.""" + yield from self.images + + def __getitem__(self, idx: int) -> image_types.GeneratedImage: + """Gets the generated image by index.""" + return self.images[idx] + + +class VideoGenerationModel: + """Generates videos from a text prompt. + + Examples: + + ``` + model = VideoGenerationModel("veo-001-preview-0815") + response = model.generate_videos( + prompt="Astronaut riding a horse", + ) + response[0].save("video.mp4") + ``` + """ + + def __init__(self, model_id: str): + if not model_id.startswith("models"): + model_id = f"models/{model_id}" + self.model_name = model_id + self._client = None + + @classmethod + def from_pretrained(cls, model_name: str): + """For vertex compatibility""" + return cls(model_name) + + def _generate_videos( + self, + prompt: str, + *, + image: Union[pathlib.Path, image_types.ImageType] = None, + aspect_ratio: Optional[VideoAspectRatio] = None, + person_generation: Optional[PersonGeneration] = None, + number_of_videos: int = 4, + ) -> "VideoGenerationResponse": + """Generates videos from a text prompt. + + Args: + prompt: Text prompt for the video. + aspect_ratio: Aspect ratio for the video. Supported values are: + * 9:16 - Portait + * 16:9 - Landscape + person_generation: Allow generation of people by the model. + Supported values are: + * "dont_allow" : Block generation of people + * "allow_adult" : Generate adults, but not children + number_of_videos: Number of videos to generate. Maximum 4. + Returns: + An `VideoGenerationResponse` object. + """ + if self._client is None: + self._client = client.get_default_prediction_client() + # Note: Only a single prompt is supported by the service. + + instance = {} + if prompt is not None: + instance.update({"prompt": prompt}) + if image is not None: + img = image_types.to_image(image) + instance.update( + { + "image": { + "bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode(), + "mimeType": img._mime_type, + } + } + ) + + parameters = {} + + if aspect_ratio is not None: + if aspect_ratio not in VIDEO_ASPECT_RATIOS: + raise ValueError(f"aspect_ratio not in {VIDEO_ASPECT_RATIOS}") + parameters["aspectRatio"] = aspect_ratio + + if person_generation is not None: + parameters["personGeneration"] = person_generation + + parameters["sampleCount"] = number_of_videos + + operation = self._client.predict_long_running( + model=self.model_name, instances=[instance], parameters=parameters + ) + operation = GenerateVideoOperation.from_core_operation(operation) + + return operation + + def generate_videos( + self, + prompt: str, + *, + image: Union[pathlib.Path, image_types.ImageType] = None, + aspect_ratio: Optional[VideoAspectRatio] = None, + person_generation: Optional[PersonGeneration] = None, + number_of_videos: int = 4, + ) -> "VideoGenerationResponse": + """Generates videos from a text prompt. + + Args: + prompt: Text prompt for the video. + aspect_ratio: Changes the aspect ratio of the generated video Supported + values are: + * "9:16" : 9:16 aspect ratio + * "16:9" : 16:9 aspect ratio + person_generation: Allow generation of people by the model. + Supported values are: + * "dont_allow" : Block generation of people + * "allow_adult" : Generate adults, but not children + number_of_videos: Number of videos to generate. Maximum 4. + Returns: + An `VideoGenerationResponse` object. + """ + return self._generate_videos( + prompt=prompt, + image=image, + aspect_ratio=aspect_ratio, + person_generation=person_generation, + number_of_videos=number_of_videos, + ) + + +@dataclasses.dataclass +class VideoGenerationResponse: + """Video generation response. + + Attributes: + videos: The list of generated videos. + """ + + videos: List["image_types.Video"] + rai_media_filtered_reasons: Optional[list[str]] = None + + def __iter__(self) -> typing.Iterator["image_types.Video"]: + """Iterates through the generated videos.""" + yield from self.videos + + def __getitem__(self, idx: int) -> "image_types.Video": + """Gets the generated video by index.""" + try: + return self.videos[idx] + except IndexError as e: + if bool(self.rai_media_filtered_reasons): + e.args = ( + f"{e.args[0]}\n\n" + f"Note: Some videos were filtered out: {self.rai_media_filtered_reasons=}", + ) + raise e + + +class GenerateVideoOperation(operations.BaseOperation): + def set_result(self, result): + response = result.generate_video_response + samples = result.generate_video_response.generated_samples + + videos = [] + for sample in samples: + video = image_types.GeneratedVideo(uri=sample.video.uri) + videos.append(video) + result = VideoGenerationResponse( + videos, rai_media_filtered_reasons=list(response.rai_media_filtered_reasons) + ) + super().set_result(result) + + def wait_bar(self, **kwargs) -> Iterator[protos.PredictLongRunningMetadata]: + """A tqdm wait bar, yields `Operation` statuses until complete. + + Args: + **kwargs: passed through to `tqdm.auto.tqdm(..., **kwargs)` + + Yields: + Operation statuses as `protos.PredictLongRunningMetadata` objects. + """ + return super().wait_bar(**kwargs) diff --git a/.venv/lib/python3.11/site-packages/google/longrunning/__pycache__/operations_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/longrunning/__pycache__/operations_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94368ff7124b226697f1ab2546718675a42eedb0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/longrunning/__pycache__/operations_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/longrunning/__pycache__/operations_proto_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/longrunning/__pycache__/operations_proto_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..719673446f9ba3011d23574f0f20d0747565e31e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/longrunning/__pycache__/operations_proto_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/rpc/__pycache__/code_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/rpc/__pycache__/code_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3205225b5196b6a8925209bf1bc02426bf642945 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/rpc/__pycache__/code_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/rpc/__pycache__/error_details_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/rpc/__pycache__/error_details_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d7476727ea7e31320988b4a804cafa8b8e3ad53 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/rpc/__pycache__/error_details_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/rpc/__pycache__/http_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/rpc/__pycache__/http_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d76e0366e1019843abc0f24156033826544777b4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/rpc/__pycache__/http_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/rpc/__pycache__/status_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/rpc/__pycache__/status_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33baf6688e492a44df66d589a47e77ce5bf982e2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/rpc/__pycache__/status_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/rpc/code.proto b/.venv/lib/python3.11/site-packages/google/rpc/code.proto new file mode 100644 index 0000000000000000000000000000000000000000..ba8f2bf9eb2139543fcb8bbfe1acc33240b4475c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/rpc/code.proto @@ -0,0 +1,186 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.rpc; + +option go_package = "google.golang.org/genproto/googleapis/rpc/code;code"; +option java_multiple_files = true; +option java_outer_classname = "CodeProto"; +option java_package = "com.google.rpc"; +option objc_class_prefix = "RPC"; + +// The canonical error codes for gRPC APIs. +// +// +// Sometimes multiple error codes may apply. Services should return +// the most specific error code that applies. For example, prefer +// `OUT_OF_RANGE` over `FAILED_PRECONDITION` if both codes apply. +// Similarly prefer `NOT_FOUND` or `ALREADY_EXISTS` over `FAILED_PRECONDITION`. +enum Code { + // Not an error; returned on success. + // + // HTTP Mapping: 200 OK + OK = 0; + + // The operation was cancelled, typically by the caller. + // + // HTTP Mapping: 499 Client Closed Request + CANCELLED = 1; + + // Unknown error. For example, this error may be returned when + // a `Status` value received from another address space belongs to + // an error space that is not known in this address space. Also + // errors raised by APIs that do not return enough error information + // may be converted to this error. + // + // HTTP Mapping: 500 Internal Server Error + UNKNOWN = 2; + + // The client specified an invalid argument. Note that this differs + // from `FAILED_PRECONDITION`. `INVALID_ARGUMENT` indicates arguments + // that are problematic regardless of the state of the system + // (e.g., a malformed file name). + // + // HTTP Mapping: 400 Bad Request + INVALID_ARGUMENT = 3; + + // The deadline expired before the operation could complete. For operations + // that change the state of the system, this error may be returned + // even if the operation has completed successfully. For example, a + // successful response from a server could have been delayed long + // enough for the deadline to expire. + // + // HTTP Mapping: 504 Gateway Timeout + DEADLINE_EXCEEDED = 4; + + // Some requested entity (e.g., file or directory) was not found. + // + // Note to server developers: if a request is denied for an entire class + // of users, such as gradual feature rollout or undocumented allowlist, + // `NOT_FOUND` may be used. If a request is denied for some users within + // a class of users, such as user-based access control, `PERMISSION_DENIED` + // must be used. + // + // HTTP Mapping: 404 Not Found + NOT_FOUND = 5; + + // The entity that a client attempted to create (e.g., file or directory) + // already exists. + // + // HTTP Mapping: 409 Conflict + ALREADY_EXISTS = 6; + + // The caller does not have permission to execute the specified + // operation. `PERMISSION_DENIED` must not be used for rejections + // caused by exhausting some resource (use `RESOURCE_EXHAUSTED` + // instead for those errors). `PERMISSION_DENIED` must not be + // used if the caller can not be identified (use `UNAUTHENTICATED` + // instead for those errors). This error code does not imply the + // request is valid or the requested entity exists or satisfies + // other pre-conditions. + // + // HTTP Mapping: 403 Forbidden + PERMISSION_DENIED = 7; + + // The request does not have valid authentication credentials for the + // operation. + // + // HTTP Mapping: 401 Unauthorized + UNAUTHENTICATED = 16; + + // Some resource has been exhausted, perhaps a per-user quota, or + // perhaps the entire file system is out of space. + // + // HTTP Mapping: 429 Too Many Requests + RESOURCE_EXHAUSTED = 8; + + // The operation was rejected because the system is not in a state + // required for the operation's execution. For example, the directory + // to be deleted is non-empty, an rmdir operation is applied to + // a non-directory, etc. + // + // Service implementors can use the following guidelines to decide + // between `FAILED_PRECONDITION`, `ABORTED`, and `UNAVAILABLE`: + // (a) Use `UNAVAILABLE` if the client can retry just the failing call. + // (b) Use `ABORTED` if the client should retry at a higher level. For + // example, when a client-specified test-and-set fails, indicating the + // client should restart a read-modify-write sequence. + // (c) Use `FAILED_PRECONDITION` if the client should not retry until + // the system state has been explicitly fixed. For example, if an "rmdir" + // fails because the directory is non-empty, `FAILED_PRECONDITION` + // should be returned since the client should not retry unless + // the files are deleted from the directory. + // + // HTTP Mapping: 400 Bad Request + FAILED_PRECONDITION = 9; + + // The operation was aborted, typically due to a concurrency issue such as + // a sequencer check failure or transaction abort. + // + // See the guidelines above for deciding between `FAILED_PRECONDITION`, + // `ABORTED`, and `UNAVAILABLE`. + // + // HTTP Mapping: 409 Conflict + ABORTED = 10; + + // The operation was attempted past the valid range. E.g., seeking or + // reading past end-of-file. + // + // Unlike `INVALID_ARGUMENT`, this error indicates a problem that may + // be fixed if the system state changes. For example, a 32-bit file + // system will generate `INVALID_ARGUMENT` if asked to read at an + // offset that is not in the range [0,2^32-1], but it will generate + // `OUT_OF_RANGE` if asked to read from an offset past the current + // file size. + // + // There is a fair bit of overlap between `FAILED_PRECONDITION` and + // `OUT_OF_RANGE`. We recommend using `OUT_OF_RANGE` (the more specific + // error) when it applies so that callers who are iterating through + // a space can easily look for an `OUT_OF_RANGE` error to detect when + // they are done. + // + // HTTP Mapping: 400 Bad Request + OUT_OF_RANGE = 11; + + // The operation is not implemented or is not supported/enabled in this + // service. + // + // HTTP Mapping: 501 Not Implemented + UNIMPLEMENTED = 12; + + // Internal errors. This means that some invariants expected by the + // underlying system have been broken. This error code is reserved + // for serious errors. + // + // HTTP Mapping: 500 Internal Server Error + INTERNAL = 13; + + // The service is currently unavailable. This is most likely a + // transient condition, which can be corrected by retrying with + // a backoff. Note that it is not always safe to retry + // non-idempotent operations. + // + // See the guidelines above for deciding between `FAILED_PRECONDITION`, + // `ABORTED`, and `UNAVAILABLE`. + // + // HTTP Mapping: 503 Service Unavailable + UNAVAILABLE = 14; + + // Unrecoverable data loss or corruption. + // + // HTTP Mapping: 500 Internal Server Error + DATA_LOSS = 15; +} diff --git a/.venv/lib/python3.11/site-packages/google/rpc/context/__pycache__/attribute_context_pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/google/rpc/context/__pycache__/attribute_context_pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4f6e3940a062b3aae10c8b028560b26444f854a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/google/rpc/context/__pycache__/attribute_context_pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/google/rpc/context/audit_context.proto b/.venv/lib/python3.11/site-packages/google/rpc/context/audit_context.proto new file mode 100644 index 0000000000000000000000000000000000000000..74945cd46bbfbc7c288d56fa20d3599c0873b5f7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/rpc/context/audit_context.proto @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.rpc.context; + +import "google/protobuf/struct.proto"; + +option cc_enable_arenas = true; +option go_package = "google.golang.org/genproto/googleapis/rpc/context;context"; +option java_multiple_files = true; +option java_outer_classname = "AuditContextProto"; +option java_package = "com.google.rpc.context"; + +// `AuditContext` provides information that is needed for audit logging. +message AuditContext { + // Serialized audit log. + bytes audit_log = 1; + + // An API request message that is scrubbed based on the method annotation. + // This field should only be filled if audit_log field is present. + // Service Control will use this to assemble a complete log for Cloud Audit + // Logs and Google internal audit logs. + google.protobuf.Struct scrubbed_request = 2; + + // An API response message that is scrubbed based on the method annotation. + // This field should only be filled if audit_log field is present. + // Service Control will use this to assemble a complete log for Cloud Audit + // Logs and Google internal audit logs. + google.protobuf.Struct scrubbed_response = 3; + + // Number of scrubbed response items. + int32 scrubbed_response_item_count = 4; + + // Audit resource name which is scrubbed. + string target_resource = 5; +} diff --git a/.venv/lib/python3.11/site-packages/google/rpc/http_pb2.py b/.venv/lib/python3.11/site-packages/google/rpc/http_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..385b0dd606ad880f4a2c1e13b782a6cc8b2e1bb1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/rpc/http_pb2.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: google/rpc/http.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x15google/rpc/http.proto\x12\ngoogle.rpc"a\n\x0bHttpRequest\x12\x0e\n\x06method\x18\x01 \x01(\t\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\'\n\x07headers\x18\x03 \x03(\x0b\x32\x16.google.rpc.HttpHeader\x12\x0c\n\x04\x62ody\x18\x04 \x01(\x0c"e\n\x0cHttpResponse\x12\x0e\n\x06status\x18\x01 \x01(\x05\x12\x0e\n\x06reason\x18\x02 \x01(\t\x12\'\n\x07headers\x18\x03 \x03(\x0b\x32\x16.google.rpc.HttpHeader\x12\x0c\n\x04\x62ody\x18\x04 \x01(\x0c"(\n\nHttpHeader\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\tBX\n\x0e\x63om.google.rpcB\tHttpProtoP\x01Z3google.golang.org/genproto/googleapis/rpc/http;http\xa2\x02\x03RPCb\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "google.rpc.http_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"\n\016com.google.rpcB\tHttpProtoP\001Z3google.golang.org/genproto/googleapis/rpc/http;http\242\002\003RPC" + _globals["_HTTPREQUEST"]._serialized_start = 37 + _globals["_HTTPREQUEST"]._serialized_end = 134 + _globals["_HTTPRESPONSE"]._serialized_start = 136 + _globals["_HTTPRESPONSE"]._serialized_end = 237 + _globals["_HTTPHEADER"]._serialized_start = 239 + _globals["_HTTPHEADER"]._serialized_end = 279 +# @@protoc_insertion_point(module_scope) diff --git a/.venv/lib/python3.11/site-packages/google/rpc/status.proto b/.venv/lib/python3.11/site-packages/google/rpc/status.proto new file mode 100644 index 0000000000000000000000000000000000000000..90b70ddf9178d2ff904464359312bb32948360a4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/google/rpc/status.proto @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.rpc; + +import "google/protobuf/any.proto"; + +option cc_enable_arenas = true; +option go_package = "google.golang.org/genproto/googleapis/rpc/status;status"; +option java_multiple_files = true; +option java_outer_classname = "StatusProto"; +option java_package = "com.google.rpc"; +option objc_class_prefix = "RPC"; + +// The `Status` type defines a logical error model that is suitable for +// different programming environments, including REST APIs and RPC APIs. It is +// used by [gRPC](https://github.com/grpc). Each `Status` message contains +// three pieces of data: error code, error message, and error details. +// +// You can find out more about this error model and how to work with it in the +// [API Design Guide](https://cloud.google.com/apis/design/errors). +message Status { + // The status code, which should be an enum value of + // [google.rpc.Code][google.rpc.Code]. + int32 code = 1; + + // A developer-facing error message, which should be in English. Any + // user-facing error message should be localized and sent in the + // [google.rpc.Status.details][google.rpc.Status.details] field, or localized + // by the client. + string message = 2; + + // A list of messages that carry the error details. There is a common set of + // message types for APIs to use. + repeated google.protobuf.Any details = 3; +}