Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/google/cloud/__pycache__/extended_operations_pb2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/cloud/extended_operations.proto +150 -0
- .venv/lib/python3.11/site-packages/google/cloud/extended_operations_pb2.py +47 -0
- .venv/lib/python3.11/site-packages/google/cloud/location/__pycache__/locations_pb2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/cloud/location/locations.proto +108 -0
- .venv/lib/python3.11/site-packages/google/cloud/location/locations_pb2.py +71 -0
- .venv/lib/python3.11/site-packages/google/generativeai/__init__.py +86 -0
- .venv/lib/python3.11/site-packages/google/generativeai/answer.py +363 -0
- .venv/lib/python3.11/site-packages/google/generativeai/audio_models/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/_audio_models.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/audio_models/_audio_models.py +3 -0
- .venv/lib/python3.11/site-packages/google/generativeai/caching.py +314 -0
- .venv/lib/python3.11/site-packages/google/generativeai/client.py +388 -0
- .venv/lib/python3.11/site-packages/google/generativeai/discuss.py +590 -0
- .venv/lib/python3.11/site-packages/google/generativeai/embedding.py +312 -0
- .venv/lib/python3.11/site-packages/google/generativeai/files.py +116 -0
- .venv/lib/python3.11/site-packages/google/generativeai/generative_models.py +875 -0
- .venv/lib/python3.11/site-packages/google/generativeai/models.py +467 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/__init__.py +32 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/command.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/compare_cmd.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/eval_cmd.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/flag_def.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/input_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/magics_engine.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/parsed_args_lib.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/run_cmd.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_id.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/cmd_line_parser.py +557 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/command.py +45 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/compare_cmd.py +71 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/compile_cmd.py +66 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/eval_cmd.py +75 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/flag_def.py +503 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/gspread_client.py +232 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/html_utils.py +46 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/input_utils.py +73 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/ipython_env_impl.py +30 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/magics.py +150 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/magics_engine.py +124 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/model_registry.py +55 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/parsed_args_lib.py +85 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/post_process_utils_test_helper.py +32 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/run_cmd.py +75 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_id.py +101 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_sanitize_url.py +89 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_utils.py +111 -0
- .venv/lib/python3.11/site-packages/google/generativeai/notebook/text_model.py +72 -0
.venv/lib/python3.11/site-packages/google/cloud/__pycache__/extended_operations_pb2.cpython-311.pyc
ADDED
|
Binary file (2.24 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/cloud/extended_operations.proto
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2024 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
// This file contains custom annotations that are used by GAPIC generators to
|
| 16 |
+
// handle Long Running Operation methods (LRO) that are NOT compliant with
|
| 17 |
+
// https://google.aip.dev/151. These annotations are public for technical
|
| 18 |
+
// reasons only. Please DO NOT USE them in your protos.
|
| 19 |
+
syntax = "proto3";
|
| 20 |
+
|
| 21 |
+
package google.cloud;
|
| 22 |
+
|
| 23 |
+
import "google/protobuf/descriptor.proto";
|
| 24 |
+
|
| 25 |
+
option go_package = "google.golang.org/genproto/googleapis/cloud/extendedops;extendedops";
|
| 26 |
+
option java_multiple_files = true;
|
| 27 |
+
option java_outer_classname = "ExtendedOperationsProto";
|
| 28 |
+
option java_package = "com.google.cloud";
|
| 29 |
+
option objc_class_prefix = "GAPI";
|
| 30 |
+
|
| 31 |
+
// FieldOptions to match corresponding fields in the initial request,
|
| 32 |
+
// polling request and operation response messages.
|
| 33 |
+
//
|
| 34 |
+
// Example:
|
| 35 |
+
//
|
| 36 |
+
// In an API-specific operation message:
|
| 37 |
+
//
|
| 38 |
+
// message MyOperation {
|
| 39 |
+
// string http_error_message = 1 [(operation_field) = ERROR_MESSAGE];
|
| 40 |
+
// int32 http_error_status_code = 2 [(operation_field) = ERROR_CODE];
|
| 41 |
+
// string id = 3 [(operation_field) = NAME];
|
| 42 |
+
// Status status = 4 [(operation_field) = STATUS];
|
| 43 |
+
// }
|
| 44 |
+
//
|
| 45 |
+
// In a polling request message (the one which is used to poll for an LRO
|
| 46 |
+
// status):
|
| 47 |
+
//
|
| 48 |
+
// message MyPollingRequest {
|
| 49 |
+
// string operation = 1 [(operation_response_field) = "id"];
|
| 50 |
+
// string project = 2;
|
| 51 |
+
// string region = 3;
|
| 52 |
+
// }
|
| 53 |
+
//
|
| 54 |
+
// In an initial request message (the one which starts an LRO):
|
| 55 |
+
//
|
| 56 |
+
// message MyInitialRequest {
|
| 57 |
+
// string my_project = 2 [(operation_request_field) = "project"];
|
| 58 |
+
// string my_region = 3 [(operation_request_field) = "region"];
|
| 59 |
+
// }
|
| 60 |
+
//
|
| 61 |
+
extend google.protobuf.FieldOptions {
|
| 62 |
+
// A field annotation that maps fields in an API-specific Operation object to
|
| 63 |
+
// their standard counterparts in google.longrunning.Operation. See
|
| 64 |
+
// OperationResponseMapping enum definition.
|
| 65 |
+
OperationResponseMapping operation_field = 1149;
|
| 66 |
+
|
| 67 |
+
// A field annotation that maps fields in the initial request message
|
| 68 |
+
// (the one which started the LRO) to their counterparts in the polling
|
| 69 |
+
// request message. For non-standard LRO, the polling response may be missing
|
| 70 |
+
// some of the information needed to make a subsequent polling request. The
|
| 71 |
+
// missing information (for example, project or region ID) is contained in the
|
| 72 |
+
// fields of the initial request message that this annotation must be applied
|
| 73 |
+
// to. The string value of the annotation corresponds to the name of the
|
| 74 |
+
// counterpart field in the polling request message that the annotated field's
|
| 75 |
+
// value will be copied to.
|
| 76 |
+
string operation_request_field = 1150;
|
| 77 |
+
|
| 78 |
+
// A field annotation that maps fields in the polling request message to their
|
| 79 |
+
// counterparts in the initial and/or polling response message. The initial
|
| 80 |
+
// and the polling methods return an API-specific Operation object. Some of
|
| 81 |
+
// the fields from that response object must be reused in the subsequent
|
| 82 |
+
// request (like operation name/ID) to fully identify the polled operation.
|
| 83 |
+
// This annotation must be applied to the fields in the polling request
|
| 84 |
+
// message, the string value of the annotation must correspond to the name of
|
| 85 |
+
// the counterpart field in the Operation response object whose value will be
|
| 86 |
+
// copied to the annotated field.
|
| 87 |
+
string operation_response_field = 1151;
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
// MethodOptions to identify the actual service and method used for operation
|
| 91 |
+
// status polling.
|
| 92 |
+
//
|
| 93 |
+
// Example:
|
| 94 |
+
//
|
| 95 |
+
// In a method, which starts an LRO:
|
| 96 |
+
//
|
| 97 |
+
// service MyService {
|
| 98 |
+
// rpc Foo(MyInitialRequest) returns (MyOperation) {
|
| 99 |
+
// option (operation_service) = "MyPollingService";
|
| 100 |
+
// }
|
| 101 |
+
// }
|
| 102 |
+
//
|
| 103 |
+
// In a polling method:
|
| 104 |
+
//
|
| 105 |
+
// service MyPollingService {
|
| 106 |
+
// rpc Get(MyPollingRequest) returns (MyOperation) {
|
| 107 |
+
// option (operation_polling_method) = true;
|
| 108 |
+
// }
|
| 109 |
+
// }
|
| 110 |
+
extend google.protobuf.MethodOptions {
|
| 111 |
+
// A method annotation that maps an LRO method (the one which starts an LRO)
|
| 112 |
+
// to the service, which will be used to poll for the operation status. The
|
| 113 |
+
// annotation must be applied to the method which starts an LRO, the string
|
| 114 |
+
// value of the annotation must correspond to the name of the service used to
|
| 115 |
+
// poll for the operation status.
|
| 116 |
+
string operation_service = 1249;
|
| 117 |
+
|
| 118 |
+
// A method annotation that marks methods that can be used for polling
|
| 119 |
+
// operation status (e.g. the MyPollingService.Get(MyPollingRequest) method).
|
| 120 |
+
bool operation_polling_method = 1250;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
// An enum to be used to mark the essential (for polling) fields in an
|
| 124 |
+
// API-specific Operation object. A custom Operation object may contain many
|
| 125 |
+
// different fields, but only few of them are essential to conduct a successful
|
| 126 |
+
// polling process.
|
| 127 |
+
enum OperationResponseMapping {
|
| 128 |
+
// Do not use.
|
| 129 |
+
UNDEFINED = 0;
|
| 130 |
+
|
| 131 |
+
// A field in an API-specific (custom) Operation object which carries the same
|
| 132 |
+
// meaning as google.longrunning.Operation.name.
|
| 133 |
+
NAME = 1;
|
| 134 |
+
|
| 135 |
+
// A field in an API-specific (custom) Operation object which carries the same
|
| 136 |
+
// meaning as google.longrunning.Operation.done. If the annotated field is of
|
| 137 |
+
// an enum type, `annotated_field_name == EnumType.DONE` semantics should be
|
| 138 |
+
// equivalent to `Operation.done == true`. If the annotated field is of type
|
| 139 |
+
// boolean, then it should follow the same semantics as Operation.done.
|
| 140 |
+
// Otherwise, a non-empty value should be treated as `Operation.done == true`.
|
| 141 |
+
STATUS = 2;
|
| 142 |
+
|
| 143 |
+
// A field in an API-specific (custom) Operation object which carries the same
|
| 144 |
+
// meaning as google.longrunning.Operation.error.code.
|
| 145 |
+
ERROR_CODE = 3;
|
| 146 |
+
|
| 147 |
+
// A field in an API-specific (custom) Operation object which carries the same
|
| 148 |
+
// meaning as google.longrunning.Operation.error.message.
|
| 149 |
+
ERROR_MESSAGE = 4;
|
| 150 |
+
}
|
.venv/lib/python3.11/site-packages/google/cloud/extended_operations_pb2.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 Google LLC
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 18 |
+
# source: google/cloud/extended_operations.proto
|
| 19 |
+
"""Generated protocol buffer code."""
|
| 20 |
+
from google.protobuf import descriptor as _descriptor
|
| 21 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 22 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 23 |
+
from google.protobuf.internal import builder as _builder
|
| 24 |
+
|
| 25 |
+
# @@protoc_insertion_point(imports)
|
| 26 |
+
|
| 27 |
+
_sym_db = _symbol_database.Default()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
| 34 |
+
b"\n&google/cloud/extended_operations.proto\x12\x0cgoogle.cloud\x1a google/protobuf/descriptor.proto*b\n\x18OperationResponseMapping\x12\r\n\tUNDEFINED\x10\x00\x12\x08\n\x04NAME\x10\x01\x12\n\n\x06STATUS\x10\x02\x12\x0e\n\nERROR_CODE\x10\x03\x12\x11\n\rERROR_MESSAGE\x10\x04:_\n\x0foperation_field\x12\x1d.google.protobuf.FieldOptions\x18\xfd\x08 \x01(\x0e\x32&.google.cloud.OperationResponseMapping:?\n\x17operation_request_field\x12\x1d.google.protobuf.FieldOptions\x18\xfe\x08 \x01(\t:@\n\x18operation_response_field\x12\x1d.google.protobuf.FieldOptions\x18\xff\x08 \x01(\t::\n\x11operation_service\x12\x1e.google.protobuf.MethodOptions\x18\xe1\t \x01(\t:A\n\x18operation_polling_method\x12\x1e.google.protobuf.MethodOptions\x18\xe2\t \x01(\x08\x42y\n\x10\x63om.google.cloudB\x17\x45xtendedOperationsProtoP\x01ZCgoogle.golang.org/genproto/googleapis/cloud/extendedops;extendedops\xa2\x02\x04GAPIb\x06proto3"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
_globals = globals()
|
| 38 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
| 39 |
+
_builder.BuildTopDescriptorsAndMessages(
|
| 40 |
+
DESCRIPTOR, "google.cloud.extended_operations_pb2", _globals
|
| 41 |
+
)
|
| 42 |
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
| 43 |
+
DESCRIPTOR._options = None
|
| 44 |
+
DESCRIPTOR._serialized_options = b"\n\020com.google.cloudB\027ExtendedOperationsProtoP\001ZCgoogle.golang.org/genproto/googleapis/cloud/extendedops;extendedops\242\002\004GAPI"
|
| 45 |
+
_globals["_OPERATIONRESPONSEMAPPING"]._serialized_start = 90
|
| 46 |
+
_globals["_OPERATIONRESPONSEMAPPING"]._serialized_end = 188
|
| 47 |
+
# @@protoc_insertion_point(module_scope)
|
.venv/lib/python3.11/site-packages/google/cloud/location/__pycache__/locations_pb2.cpython-311.pyc
ADDED
|
Binary file (3.99 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/cloud/location/locations.proto
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2024 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
syntax = "proto3";
|
| 16 |
+
|
| 17 |
+
package google.cloud.location;
|
| 18 |
+
|
| 19 |
+
import "google/api/annotations.proto";
|
| 20 |
+
import "google/protobuf/any.proto";
|
| 21 |
+
import "google/api/client.proto";
|
| 22 |
+
|
| 23 |
+
option cc_enable_arenas = true;
|
| 24 |
+
option go_package = "google.golang.org/genproto/googleapis/cloud/location;location";
|
| 25 |
+
option java_multiple_files = true;
|
| 26 |
+
option java_outer_classname = "LocationsProto";
|
| 27 |
+
option java_package = "com.google.cloud.location";
|
| 28 |
+
|
| 29 |
+
// An abstract interface that provides location-related information for
|
| 30 |
+
// a service. Service-specific metadata is provided through the
|
| 31 |
+
// [Location.metadata][google.cloud.location.Location.metadata] field.
|
| 32 |
+
service Locations {
|
| 33 |
+
option (google.api.default_host) = "cloud.googleapis.com";
|
| 34 |
+
option (google.api.oauth_scopes) = "https://www.googleapis.com/auth/cloud-platform";
|
| 35 |
+
|
| 36 |
+
// Lists information about the supported locations for this service.
|
| 37 |
+
rpc ListLocations(ListLocationsRequest) returns (ListLocationsResponse) {
|
| 38 |
+
option (google.api.http) = {
|
| 39 |
+
get: "/v1/{name=locations}"
|
| 40 |
+
additional_bindings {
|
| 41 |
+
get: "/v1/{name=projects/*}/locations"
|
| 42 |
+
}
|
| 43 |
+
};
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
// Gets information about a location.
|
| 47 |
+
rpc GetLocation(GetLocationRequest) returns (Location) {
|
| 48 |
+
option (google.api.http) = {
|
| 49 |
+
get: "/v1/{name=locations/*}"
|
| 50 |
+
additional_bindings {
|
| 51 |
+
get: "/v1/{name=projects/*/locations/*}"
|
| 52 |
+
}
|
| 53 |
+
};
|
| 54 |
+
}
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
// The request message for [Locations.ListLocations][google.cloud.location.Locations.ListLocations].
|
| 58 |
+
message ListLocationsRequest {
|
| 59 |
+
// The resource that owns the locations collection, if applicable.
|
| 60 |
+
string name = 1;
|
| 61 |
+
|
| 62 |
+
// The standard list filter.
|
| 63 |
+
string filter = 2;
|
| 64 |
+
|
| 65 |
+
// The standard list page size.
|
| 66 |
+
int32 page_size = 3;
|
| 67 |
+
|
| 68 |
+
// The standard list page token.
|
| 69 |
+
string page_token = 4;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
// The response message for [Locations.ListLocations][google.cloud.location.Locations.ListLocations].
|
| 73 |
+
message ListLocationsResponse {
|
| 74 |
+
// A list of locations that matches the specified filter in the request.
|
| 75 |
+
repeated Location locations = 1;
|
| 76 |
+
|
| 77 |
+
// The standard List next-page token.
|
| 78 |
+
string next_page_token = 2;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
// The request message for [Locations.GetLocation][google.cloud.location.Locations.GetLocation].
|
| 82 |
+
message GetLocationRequest {
|
| 83 |
+
// Resource name for the location.
|
| 84 |
+
string name = 1;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
// A resource that represents Google Cloud Platform location.
|
| 88 |
+
message Location {
|
| 89 |
+
// Resource name for the location, which may vary between implementations.
|
| 90 |
+
// For example: `"projects/example-project/locations/us-east1"`
|
| 91 |
+
string name = 1;
|
| 92 |
+
|
| 93 |
+
// The canonical id for this location. For example: `"us-east1"`.
|
| 94 |
+
string location_id = 4;
|
| 95 |
+
|
| 96 |
+
// The friendly name for this location, typically a nearby city name.
|
| 97 |
+
// For example, "Tokyo".
|
| 98 |
+
string display_name = 5;
|
| 99 |
+
|
| 100 |
+
// Cross-service attributes for the location. For example
|
| 101 |
+
//
|
| 102 |
+
// {"cloud.googleapis.com/region": "us-east1"}
|
| 103 |
+
map<string, string> labels = 2;
|
| 104 |
+
|
| 105 |
+
// Service-specific metadata. For example the available capacity at the given
|
| 106 |
+
// location.
|
| 107 |
+
google.protobuf.Any metadata = 3;
|
| 108 |
+
}
|
.venv/lib/python3.11/site-packages/google/cloud/location/locations_pb2.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 Google LLC
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 18 |
+
# source: google/cloud/location/locations.proto
|
| 19 |
+
"""Generated protocol buffer code."""
|
| 20 |
+
from google.protobuf import descriptor as _descriptor
|
| 21 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 22 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 23 |
+
from google.protobuf.internal import builder as _builder
|
| 24 |
+
|
| 25 |
+
# @@protoc_insertion_point(imports)
|
| 26 |
+
|
| 27 |
+
_sym_db = _symbol_database.Default()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
from google.api import annotations_pb2 as google_dot_api_dot_annotations__pb2
|
| 31 |
+
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
|
| 32 |
+
from google.api import client_pb2 as google_dot_api_dot_client__pb2
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
| 36 |
+
b'\n%google/cloud/location/locations.proto\x12\x15google.cloud.location\x1a\x1cgoogle/api/annotations.proto\x1a\x19google/protobuf/any.proto\x1a\x17google/api/client.proto"[\n\x14ListLocationsRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06\x66ilter\x18\x02 \x01(\t\x12\x11\n\tpage_size\x18\x03 \x01(\x05\x12\x12\n\npage_token\x18\x04 \x01(\t"d\n\x15ListLocationsResponse\x12\x32\n\tlocations\x18\x01 \x03(\x0b\x32\x1f.google.cloud.location.Location\x12\x17\n\x0fnext_page_token\x18\x02 \x01(\t""\n\x12GetLocationRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"\xd7\x01\n\x08Location\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0blocation_id\x18\x04 \x01(\t\x12\x14\n\x0c\x64isplay_name\x18\x05 \x01(\t\x12;\n\x06labels\x18\x02 \x03(\x0b\x32+.google.cloud.location.Location.LabelsEntry\x12&\n\x08metadata\x18\x03 \x01(\x0b\x32\x14.google.protobuf.Any\x1a-\n\x0bLabelsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x32\xa4\x03\n\tLocations\x12\xab\x01\n\rListLocations\x12+.google.cloud.location.ListLocationsRequest\x1a,.google.cloud.location.ListLocationsResponse"?\x82\xd3\xe4\x93\x02\x39\x12\x14/v1/{name=locations}Z!\x12\x1f/v1/{name=projects/*}/locations\x12\x9e\x01\n\x0bGetLocation\x12).google.cloud.location.GetLocationRequest\x1a\x1f.google.cloud.location.Location"C\x82\xd3\xe4\x93\x02=\x12\x16/v1/{name=locations/*}Z#\x12!/v1/{name=projects/*/locations/*}\x1aH\xca\x41\x14\x63loud.googleapis.com\xd2\x41.https://www.googleapis.com/auth/cloud-platformBo\n\x19\x63om.google.cloud.locationB\x0eLocationsProtoP\x01Z=google.golang.org/genproto/googleapis/cloud/location;location\xf8\x01\x01\x62\x06proto3'
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
_globals = globals()
|
| 40 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
| 41 |
+
_builder.BuildTopDescriptorsAndMessages(
|
| 42 |
+
DESCRIPTOR, "google.cloud.location.locations_pb2", _globals
|
| 43 |
+
)
|
| 44 |
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
| 45 |
+
DESCRIPTOR._options = None
|
| 46 |
+
DESCRIPTOR._serialized_options = b"\n\031com.google.cloud.locationB\016LocationsProtoP\001Z=google.golang.org/genproto/googleapis/cloud/location;location\370\001\001"
|
| 47 |
+
_LOCATION_LABELSENTRY._options = None
|
| 48 |
+
_LOCATION_LABELSENTRY._serialized_options = b"8\001"
|
| 49 |
+
_LOCATIONS._options = None
|
| 50 |
+
_LOCATIONS._serialized_options = b"\312A\024cloud.googleapis.com\322A.https://www.googleapis.com/auth/cloud-platform"
|
| 51 |
+
_LOCATIONS.methods_by_name["ListLocations"]._options = None
|
| 52 |
+
_LOCATIONS.methods_by_name[
|
| 53 |
+
"ListLocations"
|
| 54 |
+
]._serialized_options = b"\202\323\344\223\0029\022\024/v1/{name=locations}Z!\022\037/v1/{name=projects/*}/locations"
|
| 55 |
+
_LOCATIONS.methods_by_name["GetLocation"]._options = None
|
| 56 |
+
_LOCATIONS.methods_by_name[
|
| 57 |
+
"GetLocation"
|
| 58 |
+
]._serialized_options = b"\202\323\344\223\002=\022\026/v1/{name=locations/*}Z#\022!/v1/{name=projects/*/locations/*}"
|
| 59 |
+
_globals["_LISTLOCATIONSREQUEST"]._serialized_start = 146
|
| 60 |
+
_globals["_LISTLOCATIONSREQUEST"]._serialized_end = 237
|
| 61 |
+
_globals["_LISTLOCATIONSRESPONSE"]._serialized_start = 239
|
| 62 |
+
_globals["_LISTLOCATIONSRESPONSE"]._serialized_end = 339
|
| 63 |
+
_globals["_GETLOCATIONREQUEST"]._serialized_start = 341
|
| 64 |
+
_globals["_GETLOCATIONREQUEST"]._serialized_end = 375
|
| 65 |
+
_globals["_LOCATION"]._serialized_start = 378
|
| 66 |
+
_globals["_LOCATION"]._serialized_end = 593
|
| 67 |
+
_globals["_LOCATION_LABELSENTRY"]._serialized_start = 548
|
| 68 |
+
_globals["_LOCATION_LABELSENTRY"]._serialized_end = 593
|
| 69 |
+
_globals["_LOCATIONS"]._serialized_start = 596
|
| 70 |
+
_globals["_LOCATIONS"]._serialized_end = 1016
|
| 71 |
+
# @@protoc_insertion_point(module_scope)
|
.venv/lib/python3.11/site-packages/google/generativeai/__init__.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Google AI Python SDK
|
| 16 |
+
|
| 17 |
+
## Setup
|
| 18 |
+
|
| 19 |
+
```posix-terminal
|
| 20 |
+
pip install google-generativeai
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
## GenerativeModel
|
| 24 |
+
|
| 25 |
+
Use `genai.GenerativeModel` to access the API:
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
import google.generativeai as genai
|
| 29 |
+
import os
|
| 30 |
+
|
| 31 |
+
genai.configure(api_key=os.environ['API_KEY'])
|
| 32 |
+
|
| 33 |
+
model = genai.GenerativeModel(model_name='gemini-1.5-flash')
|
| 34 |
+
response = model.generate_content('Teach me about how an LLM works')
|
| 35 |
+
|
| 36 |
+
print(response.text)
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
See the [python quickstart](https://ai.google.dev/tutorials/python_quickstart) for more details.
|
| 40 |
+
"""
|
| 41 |
+
from __future__ import annotations
|
| 42 |
+
|
| 43 |
+
from google.generativeai import version
|
| 44 |
+
|
| 45 |
+
from google.generativeai import caching
|
| 46 |
+
from google.generativeai import protos
|
| 47 |
+
from google.generativeai import types
|
| 48 |
+
|
| 49 |
+
from google.generativeai.client import configure
|
| 50 |
+
|
| 51 |
+
from google.generativeai.embedding import embed_content
|
| 52 |
+
from google.generativeai.embedding import embed_content_async
|
| 53 |
+
|
| 54 |
+
from google.generativeai.files import upload_file
|
| 55 |
+
from google.generativeai.files import get_file
|
| 56 |
+
from google.generativeai.files import list_files
|
| 57 |
+
from google.generativeai.files import delete_file
|
| 58 |
+
|
| 59 |
+
from google.generativeai.generative_models import GenerativeModel
|
| 60 |
+
from google.generativeai.generative_models import ChatSession
|
| 61 |
+
|
| 62 |
+
from google.generativeai.models import list_models
|
| 63 |
+
from google.generativeai.models import list_tuned_models
|
| 64 |
+
|
| 65 |
+
from google.generativeai.models import get_model
|
| 66 |
+
from google.generativeai.models import get_base_model
|
| 67 |
+
from google.generativeai.models import get_tuned_model
|
| 68 |
+
|
| 69 |
+
from google.generativeai.models import create_tuned_model
|
| 70 |
+
from google.generativeai.models import update_tuned_model
|
| 71 |
+
from google.generativeai.models import delete_tuned_model
|
| 72 |
+
|
| 73 |
+
from google.generativeai.operations import list_operations
|
| 74 |
+
from google.generativeai.operations import get_operation
|
| 75 |
+
|
| 76 |
+
from google.generativeai.types import GenerationConfig
|
| 77 |
+
|
| 78 |
+
__version__ = version.__version__
|
| 79 |
+
|
| 80 |
+
del embedding
|
| 81 |
+
del files
|
| 82 |
+
del generative_models
|
| 83 |
+
del models
|
| 84 |
+
del client
|
| 85 |
+
del operations
|
| 86 |
+
del version
|
.venv/lib/python3.11/site-packages/google/generativeai/answer.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import dataclasses
|
| 18 |
+
from collections.abc import Iterable
|
| 19 |
+
import itertools
|
| 20 |
+
from typing import Any, Iterable, Union, Mapping, Optional
|
| 21 |
+
from typing_extensions import TypedDict
|
| 22 |
+
|
| 23 |
+
import google.ai.generativelanguage as glm
|
| 24 |
+
from google.generativeai import protos
|
| 25 |
+
|
| 26 |
+
from google.generativeai.client import (
|
| 27 |
+
get_default_generative_client,
|
| 28 |
+
get_default_generative_async_client,
|
| 29 |
+
)
|
| 30 |
+
from google.generativeai.types import model_types
|
| 31 |
+
from google.generativeai.types import helper_types
|
| 32 |
+
from google.generativeai.types import safety_types
|
| 33 |
+
from google.generativeai.types import content_types
|
| 34 |
+
from google.generativeai.types import retriever_types
|
| 35 |
+
from google.generativeai.types.retriever_types import MetadataFilter
|
| 36 |
+
|
| 37 |
+
DEFAULT_ANSWER_MODEL = "models/aqa"
|
| 38 |
+
|
| 39 |
+
AnswerStyle = protos.GenerateAnswerRequest.AnswerStyle
|
| 40 |
+
|
| 41 |
+
AnswerStyleOptions = Union[int, str, AnswerStyle]
|
| 42 |
+
|
| 43 |
+
_ANSWER_STYLES: dict[AnswerStyleOptions, AnswerStyle] = {
|
| 44 |
+
AnswerStyle.ANSWER_STYLE_UNSPECIFIED: AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
|
| 45 |
+
0: AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
|
| 46 |
+
"answer_style_unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
|
| 47 |
+
"unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
|
| 48 |
+
AnswerStyle.ABSTRACTIVE: AnswerStyle.ABSTRACTIVE,
|
| 49 |
+
1: AnswerStyle.ABSTRACTIVE,
|
| 50 |
+
"answer_style_abstractive": AnswerStyle.ABSTRACTIVE,
|
| 51 |
+
"abstractive": AnswerStyle.ABSTRACTIVE,
|
| 52 |
+
AnswerStyle.EXTRACTIVE: AnswerStyle.EXTRACTIVE,
|
| 53 |
+
2: AnswerStyle.EXTRACTIVE,
|
| 54 |
+
"answer_style_extractive": AnswerStyle.EXTRACTIVE,
|
| 55 |
+
"extractive": AnswerStyle.EXTRACTIVE,
|
| 56 |
+
AnswerStyle.VERBOSE: AnswerStyle.VERBOSE,
|
| 57 |
+
3: AnswerStyle.VERBOSE,
|
| 58 |
+
"answer_style_verbose": AnswerStyle.VERBOSE,
|
| 59 |
+
"verbose": AnswerStyle.VERBOSE,
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle:
|
| 64 |
+
if isinstance(x, str):
|
| 65 |
+
x = x.lower()
|
| 66 |
+
return _ANSWER_STYLES[x]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
GroundingPassageOptions = (
|
| 70 |
+
Union[
|
| 71 |
+
protos.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType
|
| 72 |
+
],
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
GroundingPassagesOptions = Union[
|
| 76 |
+
protos.GroundingPassages,
|
| 77 |
+
Iterable[GroundingPassageOptions],
|
| 78 |
+
Mapping[str, content_types.ContentType],
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _make_grounding_passages(source: GroundingPassagesOptions) -> protos.GroundingPassages:
|
| 83 |
+
"""
|
| 84 |
+
Converts the `source` into a `protos.GroundingPassage`. A `GroundingPassages` contains a list of
|
| 85 |
+
`protos.GroundingPassage` objects, which each contain a `protos.Content` and a string `id`.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
source: `Content` or a `GroundingPassagesOptions` that will be converted to protos.GroundingPassages.
|
| 89 |
+
|
| 90 |
+
Return:
|
| 91 |
+
`protos.GroundingPassages` to be passed into `protos.GenerateAnswer`.
|
| 92 |
+
"""
|
| 93 |
+
if isinstance(source, protos.GroundingPassages):
|
| 94 |
+
return source
|
| 95 |
+
|
| 96 |
+
if not isinstance(source, Iterable):
|
| 97 |
+
raise TypeError(
|
| 98 |
+
f"Invalid input: The 'source' argument must be an instance of 'GroundingPassagesOptions'. Received a '{type(source).__name__}' object instead."
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
passages = []
|
| 102 |
+
if isinstance(source, Mapping):
|
| 103 |
+
source = source.items()
|
| 104 |
+
|
| 105 |
+
for n, data in enumerate(source):
|
| 106 |
+
if isinstance(data, protos.GroundingPassage):
|
| 107 |
+
passages.append(data)
|
| 108 |
+
elif isinstance(data, tuple):
|
| 109 |
+
id, content = data # tuple must have exactly 2 items.
|
| 110 |
+
passages.append({"id": id, "content": content_types.to_content(content)})
|
| 111 |
+
else:
|
| 112 |
+
passages.append({"id": str(n), "content": content_types.to_content(data)})
|
| 113 |
+
|
| 114 |
+
return protos.GroundingPassages(passages=passages)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
SourceNameType = Union[
|
| 118 |
+
str, retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class SemanticRetrieverConfigDict(TypedDict):
|
| 123 |
+
source: SourceNameType
|
| 124 |
+
query: content_types.ContentsType
|
| 125 |
+
metadata_filter: Optional[Iterable[MetadataFilter]]
|
| 126 |
+
max_chunks_count: Optional[int]
|
| 127 |
+
minimum_relevance_score: Optional[float]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
SemanticRetrieverConfigOptions = Union[
|
| 131 |
+
SourceNameType,
|
| 132 |
+
SemanticRetrieverConfigDict,
|
| 133 |
+
protos.SemanticRetrieverConfig,
|
| 134 |
+
]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _maybe_get_source_name(source) -> str | None:
|
| 138 |
+
if isinstance(source, str):
|
| 139 |
+
return source
|
| 140 |
+
elif isinstance(
|
| 141 |
+
source, (retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document)
|
| 142 |
+
):
|
| 143 |
+
return source.name
|
| 144 |
+
else:
|
| 145 |
+
return None
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _make_semantic_retriever_config(
|
| 149 |
+
source: SemanticRetrieverConfigOptions,
|
| 150 |
+
query: content_types.ContentsType,
|
| 151 |
+
) -> protos.SemanticRetrieverConfig:
|
| 152 |
+
if isinstance(source, protos.SemanticRetrieverConfig):
|
| 153 |
+
return source
|
| 154 |
+
|
| 155 |
+
name = _maybe_get_source_name(source)
|
| 156 |
+
if name is not None:
|
| 157 |
+
source = {"source": name}
|
| 158 |
+
elif isinstance(source, dict):
|
| 159 |
+
source["source"] = _maybe_get_source_name(source["source"])
|
| 160 |
+
else:
|
| 161 |
+
raise TypeError(
|
| 162 |
+
f"Invalid input: Failed to create a 'protos.SemanticRetrieverConfig' from the provided source. "
|
| 163 |
+
f"Received type: {type(source).__name__}, "
|
| 164 |
+
f"Received value: {source}"
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
if source["query"] is None:
|
| 168 |
+
source["query"] = query
|
| 169 |
+
elif isinstance(source["query"], str):
|
| 170 |
+
source["query"] = content_types.to_content(source["query"])
|
| 171 |
+
|
| 172 |
+
return protos.SemanticRetrieverConfig(source)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _make_generate_answer_request(
|
| 176 |
+
*,
|
| 177 |
+
model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
|
| 178 |
+
contents: content_types.ContentsType,
|
| 179 |
+
inline_passages: GroundingPassagesOptions | None = None,
|
| 180 |
+
semantic_retriever: SemanticRetrieverConfigOptions | None = None,
|
| 181 |
+
answer_style: AnswerStyle | None = None,
|
| 182 |
+
safety_settings: safety_types.SafetySettingOptions | None = None,
|
| 183 |
+
temperature: float | None = None,
|
| 184 |
+
) -> protos.GenerateAnswerRequest:
|
| 185 |
+
"""
|
| 186 |
+
constructs a protos.GenerateAnswerRequest object by organizing the input parameters for the API call to generate a grounded answer from the model.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
model: Name of the model used to generate the grounded response.
|
| 190 |
+
contents: Content of the current conversation with the model. For single-turn query, this is a
|
| 191 |
+
single question to answer. For multi-turn queries, this is a repeated field that contains
|
| 192 |
+
conversation history and the last `Content` in the list containing the question.
|
| 193 |
+
inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs,
|
| 194 |
+
or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retriever`,
|
| 195 |
+
one must be set, but not both.
|
| 196 |
+
semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with
|
| 197 |
+
`inline_passages`, one must be set, but not both.
|
| 198 |
+
answer_style: Style for grounded answers.
|
| 199 |
+
safety_settings: Safety settings for generated output.
|
| 200 |
+
temperature: The temperature for randomness in the output.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
Call for protos.GenerateAnswerRequest().
|
| 204 |
+
"""
|
| 205 |
+
model = model_types.make_model_name(model)
|
| 206 |
+
|
| 207 |
+
contents = content_types.to_contents(contents)
|
| 208 |
+
|
| 209 |
+
if safety_settings:
|
| 210 |
+
safety_settings = safety_types.normalize_safety_settings(safety_settings)
|
| 211 |
+
|
| 212 |
+
if inline_passages is not None and semantic_retriever is not None:
|
| 213 |
+
raise ValueError(
|
| 214 |
+
f"Invalid configuration: Please set either 'inline_passages' or 'semantic_retriever_config', but not both. "
|
| 215 |
+
f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}."
|
| 216 |
+
)
|
| 217 |
+
elif inline_passages is not None:
|
| 218 |
+
inline_passages = _make_grounding_passages(inline_passages)
|
| 219 |
+
elif semantic_retriever is not None:
|
| 220 |
+
semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1])
|
| 221 |
+
else:
|
| 222 |
+
raise TypeError(
|
| 223 |
+
f"Invalid configuration: Either 'inline_passages' or 'semantic_retriever_config' must be provided, but currently both are 'None'. "
|
| 224 |
+
f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}."
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if answer_style:
|
| 228 |
+
answer_style = to_answer_style(answer_style)
|
| 229 |
+
|
| 230 |
+
return protos.GenerateAnswerRequest(
|
| 231 |
+
model=model,
|
| 232 |
+
contents=contents,
|
| 233 |
+
inline_passages=inline_passages,
|
| 234 |
+
semantic_retriever=semantic_retriever,
|
| 235 |
+
safety_settings=safety_settings,
|
| 236 |
+
temperature=temperature,
|
| 237 |
+
answer_style=answer_style,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def generate_answer(
|
| 242 |
+
*,
|
| 243 |
+
model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
|
| 244 |
+
contents: content_types.ContentsType,
|
| 245 |
+
inline_passages: GroundingPassagesOptions | None = None,
|
| 246 |
+
semantic_retriever: SemanticRetrieverConfigOptions | None = None,
|
| 247 |
+
answer_style: AnswerStyle | None = None,
|
| 248 |
+
safety_settings: safety_types.SafetySettingOptions | None = None,
|
| 249 |
+
temperature: float | None = None,
|
| 250 |
+
client: glm.GenerativeServiceClient | None = None,
|
| 251 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 252 |
+
):
|
| 253 |
+
"""Calls the GenerateAnswer API and returns a `types.Answer` containing the response.
|
| 254 |
+
|
| 255 |
+
You can pass a literal list of text chunks:
|
| 256 |
+
|
| 257 |
+
>>> from google.generativeai import answer
|
| 258 |
+
>>> answer.generate_answer(
|
| 259 |
+
... content=question,
|
| 260 |
+
... inline_passages=splitter.split(document)
|
| 261 |
+
... )
|
| 262 |
+
|
| 263 |
+
Or pass a reference to a retreiver Document or Corpus:
|
| 264 |
+
|
| 265 |
+
>>> from google.generativeai import answer
|
| 266 |
+
>>> from google.generativeai import retriever
|
| 267 |
+
>>> my_corpus = retriever.get_corpus('my_corpus')
|
| 268 |
+
>>> genai.generate_answer(
|
| 269 |
+
... content=question,
|
| 270 |
+
... semantic_retriever=my_corpus
|
| 271 |
+
... )
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
model: Which model to call, as a string or a `types.Model`.
|
| 276 |
+
contents: The question to be answered by the model, grounded in the
|
| 277 |
+
provided source.
|
| 278 |
+
inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs,
|
| 279 |
+
or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retriever`,
|
| 280 |
+
one must be set, but not both.
|
| 281 |
+
semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with
|
| 282 |
+
`inline_passages`, one must be set, but not both.
|
| 283 |
+
answer_style: Style in which the grounded answer should be returned.
|
| 284 |
+
safety_settings: Safety settings for generated output. Defaults to None.
|
| 285 |
+
temperature: Controls the randomness of the output.
|
| 286 |
+
client: If you're not relying on a default client, you pass a `glm.GenerativeServiceClient` instead.
|
| 287 |
+
request_options: Options for the request.
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
A `types.Answer` containing the model's text answer response.
|
| 291 |
+
"""
|
| 292 |
+
if request_options is None:
|
| 293 |
+
request_options = {}
|
| 294 |
+
|
| 295 |
+
if client is None:
|
| 296 |
+
client = get_default_generative_client()
|
| 297 |
+
|
| 298 |
+
request = _make_generate_answer_request(
|
| 299 |
+
model=model,
|
| 300 |
+
contents=contents,
|
| 301 |
+
inline_passages=inline_passages,
|
| 302 |
+
semantic_retriever=semantic_retriever,
|
| 303 |
+
safety_settings=safety_settings,
|
| 304 |
+
temperature=temperature,
|
| 305 |
+
answer_style=answer_style,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
response = client.generate_answer(request, **request_options)
|
| 309 |
+
|
| 310 |
+
return response
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
async def generate_answer_async(
|
| 314 |
+
*,
|
| 315 |
+
model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
|
| 316 |
+
contents: content_types.ContentsType,
|
| 317 |
+
inline_passages: GroundingPassagesOptions | None = None,
|
| 318 |
+
semantic_retriever: SemanticRetrieverConfigOptions | None = None,
|
| 319 |
+
answer_style: AnswerStyle | None = None,
|
| 320 |
+
safety_settings: safety_types.SafetySettingOptions | None = None,
|
| 321 |
+
temperature: float | None = None,
|
| 322 |
+
client: glm.GenerativeServiceClient | None = None,
|
| 323 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 324 |
+
):
|
| 325 |
+
"""
|
| 326 |
+
Calls the API and returns a `types.Answer` containing the answer.
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
model: Which model to call, as a string or a `types.Model`.
|
| 330 |
+
contents: The question to be answered by the model, grounded in the
|
| 331 |
+
provided source.
|
| 332 |
+
inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs,
|
| 333 |
+
or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retriever`,
|
| 334 |
+
one must be set, but not both.
|
| 335 |
+
semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with
|
| 336 |
+
`inline_passages`, one must be set, but not both.
|
| 337 |
+
answer_style: Style in which the grounded answer should be returned.
|
| 338 |
+
safety_settings: Safety settings for generated output. Defaults to None.
|
| 339 |
+
temperature: Controls the randomness of the output.
|
| 340 |
+
client: If you're not relying on a default client, you pass a `glm.GenerativeServiceClient` instead.
|
| 341 |
+
|
| 342 |
+
Returns:
|
| 343 |
+
A `types.Answer` containing the model's text answer response.
|
| 344 |
+
"""
|
| 345 |
+
if request_options is None:
|
| 346 |
+
request_options = {}
|
| 347 |
+
|
| 348 |
+
if client is None:
|
| 349 |
+
client = get_default_generative_async_client()
|
| 350 |
+
|
| 351 |
+
request = _make_generate_answer_request(
|
| 352 |
+
model=model,
|
| 353 |
+
contents=contents,
|
| 354 |
+
inline_passages=inline_passages,
|
| 355 |
+
semantic_retriever=semantic_retriever,
|
| 356 |
+
safety_settings=safety_settings,
|
| 357 |
+
temperature=temperature,
|
| 358 |
+
answer_style=answer_style,
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
response = await client.generate_answer(request, **request_options)
|
| 362 |
+
|
| 363 |
+
return response
|
.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (205 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/_audio_models.cpython-311.pyc
ADDED
|
Binary file (629 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/audio_models/_audio_models.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class SpeechGenerationModel:
|
| 2 |
+
def __init__(self, model_name):
|
| 3 |
+
self.model_name = model_name
|
.venv/lib/python3.11/site-packages/google/generativeai/caching.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2024 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import datetime
|
| 18 |
+
import textwrap
|
| 19 |
+
from typing import Iterable, Optional
|
| 20 |
+
|
| 21 |
+
from google.generativeai import protos
|
| 22 |
+
from google.generativeai.types import caching_types
|
| 23 |
+
from google.generativeai.types import content_types
|
| 24 |
+
from google.generativeai.client import get_default_cache_client
|
| 25 |
+
|
| 26 |
+
from google.protobuf import field_mask_pb2
|
| 27 |
+
|
| 28 |
+
_USER_ROLE = "user"
|
| 29 |
+
_MODEL_ROLE = "model"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class CachedContent:
|
| 33 |
+
"""Cached content resource."""
|
| 34 |
+
|
| 35 |
+
def __init__(self, name):
|
| 36 |
+
"""Fetches a `CachedContent` resource.
|
| 37 |
+
|
| 38 |
+
Identical to `CachedContent.get`.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
name: The resource name referring to the cached content.
|
| 42 |
+
"""
|
| 43 |
+
client = get_default_cache_client()
|
| 44 |
+
|
| 45 |
+
if "cachedContents/" not in name:
|
| 46 |
+
name = "cachedContents/" + name
|
| 47 |
+
|
| 48 |
+
request = protos.GetCachedContentRequest(name=name)
|
| 49 |
+
response = client.get_cached_content(request)
|
| 50 |
+
self._proto = response
|
| 51 |
+
|
| 52 |
+
@property
|
| 53 |
+
def name(self) -> str:
|
| 54 |
+
return self._proto.name
|
| 55 |
+
|
| 56 |
+
@property
|
| 57 |
+
def model(self) -> str:
|
| 58 |
+
return self._proto.model
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def display_name(self) -> str:
|
| 62 |
+
return self._proto.display_name
|
| 63 |
+
|
| 64 |
+
@property
|
| 65 |
+
def usage_metadata(self) -> protos.CachedContent.UsageMetadata:
|
| 66 |
+
return self._proto.usage_metadata
|
| 67 |
+
|
| 68 |
+
@property
|
| 69 |
+
def create_time(self) -> datetime.datetime:
|
| 70 |
+
return self._proto.create_time
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def update_time(self) -> datetime.datetime:
|
| 74 |
+
return self._proto.update_time
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
def expire_time(self) -> datetime.datetime:
|
| 78 |
+
return self._proto.expire_time
|
| 79 |
+
|
| 80 |
+
def __str__(self):
|
| 81 |
+
return textwrap.dedent(
|
| 82 |
+
f"""\
|
| 83 |
+
CachedContent(
|
| 84 |
+
name='{self.name}',
|
| 85 |
+
model='{self.model}',
|
| 86 |
+
display_name='{self.display_name}',
|
| 87 |
+
usage_metadata={'{'}
|
| 88 |
+
'total_token_count': {self.usage_metadata.total_token_count},
|
| 89 |
+
{'}'},
|
| 90 |
+
create_time={self.create_time},
|
| 91 |
+
update_time={self.update_time},
|
| 92 |
+
expire_time={self.expire_time}
|
| 93 |
+
)"""
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
__repr__ = __str__
|
| 97 |
+
|
| 98 |
+
@classmethod
|
| 99 |
+
def _from_obj(cls, obj: CachedContent | protos.CachedContent | dict) -> CachedContent:
|
| 100 |
+
"""Creates an instance of CachedContent form an object, without calling `get`."""
|
| 101 |
+
self = cls.__new__(cls)
|
| 102 |
+
self._proto = protos.CachedContent()
|
| 103 |
+
self._update(obj)
|
| 104 |
+
return self
|
| 105 |
+
|
| 106 |
+
def _update(self, updates):
|
| 107 |
+
"""Updates this instance inplace, does not call the API's `update` method"""
|
| 108 |
+
if isinstance(updates, CachedContent):
|
| 109 |
+
updates = updates._proto
|
| 110 |
+
|
| 111 |
+
if not isinstance(updates, dict):
|
| 112 |
+
updates = type(updates).to_dict(updates, including_default_value_fields=False)
|
| 113 |
+
|
| 114 |
+
for key, value in updates.items():
|
| 115 |
+
setattr(self._proto, key, value)
|
| 116 |
+
|
| 117 |
+
@staticmethod
|
| 118 |
+
def _prepare_create_request(
|
| 119 |
+
model: str,
|
| 120 |
+
*,
|
| 121 |
+
display_name: str | None = None,
|
| 122 |
+
system_instruction: Optional[content_types.ContentType] = None,
|
| 123 |
+
contents: Optional[content_types.ContentsType] = None,
|
| 124 |
+
tools: Optional[content_types.FunctionLibraryType] = None,
|
| 125 |
+
tool_config: Optional[content_types.ToolConfigType] = None,
|
| 126 |
+
ttl: Optional[caching_types.TTLTypes] = None,
|
| 127 |
+
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
|
| 128 |
+
) -> protos.CreateCachedContentRequest:
|
| 129 |
+
"""Prepares a CreateCachedContentRequest."""
|
| 130 |
+
if ttl and expire_time:
|
| 131 |
+
raise ValueError(
|
| 132 |
+
"Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
if "/" not in model:
|
| 136 |
+
model = "models/" + model
|
| 137 |
+
|
| 138 |
+
if display_name and len(display_name) > 128:
|
| 139 |
+
raise ValueError("`display_name` must be no more than 128 unicode characters.")
|
| 140 |
+
|
| 141 |
+
if system_instruction:
|
| 142 |
+
system_instruction = content_types.to_content(system_instruction)
|
| 143 |
+
|
| 144 |
+
tools_lib = content_types.to_function_library(tools)
|
| 145 |
+
if tools_lib:
|
| 146 |
+
tools_lib = tools_lib.to_proto()
|
| 147 |
+
|
| 148 |
+
if tool_config:
|
| 149 |
+
tool_config = content_types.to_tool_config(tool_config)
|
| 150 |
+
|
| 151 |
+
if contents:
|
| 152 |
+
contents = content_types.to_contents(contents)
|
| 153 |
+
if not contents[-1].role:
|
| 154 |
+
contents[-1].role = _USER_ROLE
|
| 155 |
+
|
| 156 |
+
ttl = caching_types.to_optional_ttl(ttl)
|
| 157 |
+
expire_time = caching_types.to_optional_expire_time(expire_time)
|
| 158 |
+
|
| 159 |
+
cached_content = protos.CachedContent(
|
| 160 |
+
model=model,
|
| 161 |
+
display_name=display_name,
|
| 162 |
+
system_instruction=system_instruction,
|
| 163 |
+
contents=contents,
|
| 164 |
+
tools=tools_lib,
|
| 165 |
+
tool_config=tool_config,
|
| 166 |
+
ttl=ttl,
|
| 167 |
+
expire_time=expire_time,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
return protos.CreateCachedContentRequest(cached_content=cached_content)
|
| 171 |
+
|
| 172 |
+
@classmethod
|
| 173 |
+
def create(
|
| 174 |
+
cls,
|
| 175 |
+
model: str,
|
| 176 |
+
*,
|
| 177 |
+
display_name: str | None = None,
|
| 178 |
+
system_instruction: Optional[content_types.ContentType] = None,
|
| 179 |
+
contents: Optional[content_types.ContentsType] = None,
|
| 180 |
+
tools: Optional[content_types.FunctionLibraryType] = None,
|
| 181 |
+
tool_config: Optional[content_types.ToolConfigType] = None,
|
| 182 |
+
ttl: Optional[caching_types.TTLTypes] = None,
|
| 183 |
+
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
|
| 184 |
+
) -> CachedContent:
|
| 185 |
+
"""Creates `CachedContent` resource.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
model: The name of the `model` to use for cached content creation.
|
| 189 |
+
Any `CachedContent` resource can be only used with the
|
| 190 |
+
`model` it was created for.
|
| 191 |
+
display_name: The user-generated meaningful display name
|
| 192 |
+
of the cached content. `display_name` must be no
|
| 193 |
+
more than 128 unicode characters.
|
| 194 |
+
system_instruction: Developer set system instruction.
|
| 195 |
+
contents: Contents to cache.
|
| 196 |
+
tools: A list of `Tools` the model may use to generate response.
|
| 197 |
+
tool_config: Config to apply to all tools.
|
| 198 |
+
ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
|
| 199 |
+
`ttl` and `expire_time` are exclusive arguments.
|
| 200 |
+
expire_time: Expiration time for cached resource.
|
| 201 |
+
`ttl` and `expire_time` are exclusive arguments.
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
`CachedContent` resource with specified name.
|
| 205 |
+
"""
|
| 206 |
+
client = get_default_cache_client()
|
| 207 |
+
|
| 208 |
+
request = cls._prepare_create_request(
|
| 209 |
+
model=model,
|
| 210 |
+
display_name=display_name,
|
| 211 |
+
system_instruction=system_instruction,
|
| 212 |
+
contents=contents,
|
| 213 |
+
tools=tools,
|
| 214 |
+
tool_config=tool_config,
|
| 215 |
+
ttl=ttl,
|
| 216 |
+
expire_time=expire_time,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
response = client.create_cached_content(request)
|
| 220 |
+
result = CachedContent._from_obj(response)
|
| 221 |
+
return result
|
| 222 |
+
|
| 223 |
+
@classmethod
|
| 224 |
+
def get(cls, name: str) -> CachedContent:
|
| 225 |
+
"""Fetches required `CachedContent` resource.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
name: The resource name referring to the cached content.
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
`CachedContent` resource with specified `name`.
|
| 232 |
+
"""
|
| 233 |
+
client = get_default_cache_client()
|
| 234 |
+
|
| 235 |
+
if "cachedContents/" not in name:
|
| 236 |
+
name = "cachedContents/" + name
|
| 237 |
+
|
| 238 |
+
request = protos.GetCachedContentRequest(name=name)
|
| 239 |
+
response = client.get_cached_content(request)
|
| 240 |
+
result = CachedContent._from_obj(response)
|
| 241 |
+
return result
|
| 242 |
+
|
| 243 |
+
@classmethod
|
| 244 |
+
def list(cls, page_size: Optional[int] = 1) -> Iterable[CachedContent]:
|
| 245 |
+
"""Lists `CachedContent` objects associated with the project.
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
page_size: The maximum number of permissions to return (per page).
|
| 249 |
+
The service may return fewer `CachedContent` objects.
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
A paginated list of `CachedContent` objects.
|
| 253 |
+
"""
|
| 254 |
+
client = get_default_cache_client()
|
| 255 |
+
|
| 256 |
+
request = protos.ListCachedContentsRequest(page_size=page_size)
|
| 257 |
+
for cached_content in client.list_cached_contents(request):
|
| 258 |
+
cached_content = CachedContent._from_obj(cached_content)
|
| 259 |
+
yield cached_content
|
| 260 |
+
|
| 261 |
+
def delete(self) -> None:
|
| 262 |
+
"""Deletes `CachedContent` resource."""
|
| 263 |
+
client = get_default_cache_client()
|
| 264 |
+
|
| 265 |
+
request = protos.DeleteCachedContentRequest(name=self.name)
|
| 266 |
+
client.delete_cached_content(request)
|
| 267 |
+
return
|
| 268 |
+
|
| 269 |
+
def update(
|
| 270 |
+
self,
|
| 271 |
+
*,
|
| 272 |
+
ttl: Optional[caching_types.TTLTypes] = None,
|
| 273 |
+
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
|
| 274 |
+
) -> None:
|
| 275 |
+
"""Updates requested `CachedContent` resource.
|
| 276 |
+
|
| 277 |
+
Args:
|
| 278 |
+
ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
|
| 279 |
+
`ttl` and `expire_time` are exclusive arguments.
|
| 280 |
+
expire_time: Expiration time for cached resource.
|
| 281 |
+
`ttl` and `expire_time` are exclusive arguments.
|
| 282 |
+
"""
|
| 283 |
+
client = get_default_cache_client()
|
| 284 |
+
|
| 285 |
+
if ttl and expire_time:
|
| 286 |
+
raise ValueError(
|
| 287 |
+
"Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
ttl = caching_types.to_optional_ttl(ttl)
|
| 291 |
+
expire_time = caching_types.to_optional_expire_time(expire_time)
|
| 292 |
+
|
| 293 |
+
updates = protos.CachedContent(
|
| 294 |
+
name=self.name,
|
| 295 |
+
ttl=ttl,
|
| 296 |
+
expire_time=expire_time,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
field_mask = field_mask_pb2.FieldMask()
|
| 300 |
+
|
| 301 |
+
if ttl:
|
| 302 |
+
field_mask.paths.append("ttl")
|
| 303 |
+
elif expire_time:
|
| 304 |
+
field_mask.paths.append("expire_time")
|
| 305 |
+
else:
|
| 306 |
+
raise ValueError(
|
| 307 |
+
f"Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`."
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
request = protos.UpdateCachedContentRequest(cached_content=updates, update_mask=field_mask)
|
| 311 |
+
updated_cc = client.update_cached_content(request)
|
| 312 |
+
self._update(updated_cc)
|
| 313 |
+
|
| 314 |
+
return
|
.venv/lib/python3.11/site-packages/google/generativeai/client.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import contextlib
|
| 5 |
+
import inspect
|
| 6 |
+
import dataclasses
|
| 7 |
+
import pathlib
|
| 8 |
+
import threading
|
| 9 |
+
from typing import Any, cast
|
| 10 |
+
from collections.abc import Sequence
|
| 11 |
+
import httplib2
|
| 12 |
+
from io import IOBase
|
| 13 |
+
|
| 14 |
+
import google.ai.generativelanguage as glm
|
| 15 |
+
import google.generativeai.protos as protos
|
| 16 |
+
|
| 17 |
+
from google.auth import credentials as ga_credentials
|
| 18 |
+
from google.auth import exceptions as ga_exceptions
|
| 19 |
+
from google import auth
|
| 20 |
+
from google.api_core import client_options as client_options_lib
|
| 21 |
+
from google.api_core import gapic_v1
|
| 22 |
+
from google.api_core import operations_v1
|
| 23 |
+
|
| 24 |
+
import googleapiclient.http
|
| 25 |
+
import googleapiclient.discovery
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from google.generativeai import version
|
| 29 |
+
|
| 30 |
+
__version__ = version.__version__
|
| 31 |
+
except ImportError:
|
| 32 |
+
__version__ = "0.0.0"
|
| 33 |
+
|
| 34 |
+
USER_AGENT = "genai-py"
|
| 35 |
+
|
| 36 |
+
#### Caution! ####
|
| 37 |
+
# - It would make sense for the discovery URL to respect the client_options.endpoint setting.
|
| 38 |
+
# - That would make testing Files on the staging server possible.
|
| 39 |
+
# - We tried fixing this once, but broke colab in the process because their endpoint didn't forward the discovery
|
| 40 |
+
# requests. https://github.com/google-gemini/generative-ai-python/pull/333
|
| 41 |
+
# - Kaggle would have a similar problem (b/362278209).
|
| 42 |
+
# - I think their proxy would forward the discovery traffic.
|
| 43 |
+
# - But they don't need to intercept the files-service at all, and uploads of large files could overload them.
|
| 44 |
+
# - Do the scotty uploads go to the same domain?
|
| 45 |
+
# - If you do route the discovery call to kaggle, be sure to attach the default_metadata (they need it).
|
| 46 |
+
# - One solution to all this would be if configure could take overrides per service.
|
| 47 |
+
# - set client_options.endpoint, but use a different endpoint for file service? It's not clear how best to do that
|
| 48 |
+
# through the file service.
|
| 49 |
+
##################
|
| 50 |
+
GENAI_API_DISCOVERY_URL = "https://generativelanguage.googleapis.com/$discovery/rest"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@contextlib.contextmanager
|
| 54 |
+
def patch_colab_gce_credentials():
|
| 55 |
+
get_gce = auth._default._get_gce_credentials
|
| 56 |
+
if "COLAB_RELEASE_TAG" in os.environ:
|
| 57 |
+
auth._default._get_gce_credentials = lambda *args, **kwargs: (None, None)
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
yield
|
| 61 |
+
finally:
|
| 62 |
+
auth._default._get_gce_credentials = get_gce
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class FileServiceClient(glm.FileServiceClient):
|
| 66 |
+
def __init__(self, *args, **kwargs):
|
| 67 |
+
self._discovery_api = None
|
| 68 |
+
self._local = threading.local()
|
| 69 |
+
super().__init__(*args, **kwargs)
|
| 70 |
+
|
| 71 |
+
def _setup_discovery_api(self, metadata: dict | Sequence[tuple[str, str]] = ()):
|
| 72 |
+
api_key = self._client_options.api_key
|
| 73 |
+
if api_key is None:
|
| 74 |
+
raise ValueError(
|
| 75 |
+
"Invalid operation: Uploading to the File API requires an API key. Please provide a valid API key."
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
request = googleapiclient.http.HttpRequest(
|
| 79 |
+
http=httplib2.Http(),
|
| 80 |
+
postproc=lambda resp, content: (resp, content),
|
| 81 |
+
uri=f"{GENAI_API_DISCOVERY_URL}?version=v1beta&key={api_key}",
|
| 82 |
+
headers=dict(metadata),
|
| 83 |
+
)
|
| 84 |
+
response, content = request.execute()
|
| 85 |
+
request.http.close()
|
| 86 |
+
|
| 87 |
+
discovery_doc = content.decode("utf-8")
|
| 88 |
+
self._local.discovery_api = googleapiclient.discovery.build_from_document(
|
| 89 |
+
discovery_doc, developerKey=api_key
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
def create_file(
|
| 93 |
+
self,
|
| 94 |
+
path: str | pathlib.Path | os.PathLike | IOBase,
|
| 95 |
+
*,
|
| 96 |
+
mime_type: str | None = None,
|
| 97 |
+
name: str | None = None,
|
| 98 |
+
display_name: str | None = None,
|
| 99 |
+
resumable: bool = True,
|
| 100 |
+
metadata: Sequence[tuple[str, str]] = (),
|
| 101 |
+
) -> protos.File:
|
| 102 |
+
if self._discovery_api is None:
|
| 103 |
+
self._setup_discovery_api(metadata)
|
| 104 |
+
|
| 105 |
+
file = {}
|
| 106 |
+
if name is not None:
|
| 107 |
+
file["name"] = name
|
| 108 |
+
if display_name is not None:
|
| 109 |
+
file["displayName"] = display_name
|
| 110 |
+
|
| 111 |
+
if isinstance(path, IOBase):
|
| 112 |
+
media = googleapiclient.http.MediaIoBaseUpload(
|
| 113 |
+
fd=path, mimetype=mime_type, resumable=resumable
|
| 114 |
+
)
|
| 115 |
+
else:
|
| 116 |
+
media = googleapiclient.http.MediaFileUpload(
|
| 117 |
+
filename=path, mimetype=mime_type, resumable=resumable
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
request = self._local.discovery_api.media().upload(body={"file": file}, media_body=media)
|
| 121 |
+
for key, value in metadata:
|
| 122 |
+
request.headers[key] = value
|
| 123 |
+
result = request.execute()
|
| 124 |
+
|
| 125 |
+
return self.get_file({"name": result["file"]["name"]})
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class FileServiceAsyncClient(glm.FileServiceAsyncClient):
|
| 129 |
+
async def create_file(self, *args, **kwargs):
|
| 130 |
+
raise NotImplementedError(
|
| 131 |
+
"The `create_file` method is currently not supported for the asynchronous client."
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@dataclasses.dataclass
|
| 136 |
+
class _ClientManager:
|
| 137 |
+
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
|
| 138 |
+
default_metadata: Sequence[tuple[str, str]] = ()
|
| 139 |
+
clients: dict[str, Any] = dataclasses.field(default_factory=dict)
|
| 140 |
+
|
| 141 |
+
def configure(
|
| 142 |
+
self,
|
| 143 |
+
*,
|
| 144 |
+
api_key: str | None = None,
|
| 145 |
+
credentials: ga_credentials.Credentials | dict | None = None,
|
| 146 |
+
# The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'.
|
| 147 |
+
# See _transport_registry in the google.ai.generativelanguage package.
|
| 148 |
+
# Since the transport classes align with the client classes it wouldn't make
|
| 149 |
+
# sense to accept a `Transport` object here even though the client classes can.
|
| 150 |
+
# We could accept a dict since all the `Transport` classes take the same args,
|
| 151 |
+
# but that seems rare. Users that need it can just switch to the low level API.
|
| 152 |
+
transport: str | None = None,
|
| 153 |
+
client_options: client_options_lib.ClientOptions | dict[str, Any] | None = None,
|
| 154 |
+
client_info: gapic_v1.client_info.ClientInfo | None = None,
|
| 155 |
+
default_metadata: Sequence[tuple[str, str]] = (),
|
| 156 |
+
) -> None:
|
| 157 |
+
"""Initializes default client configurations using specified parameters or environment variables.
|
| 158 |
+
|
| 159 |
+
If no API key has been provided (either directly, or on `client_options`) and the
|
| 160 |
+
`GEMINI_API_KEY` environment variable is set, it will be used as the API key. If not,
|
| 161 |
+
if the `GOOGLE_API_KEY` environement variable is set, it will be used as the API key.
|
| 162 |
+
|
| 163 |
+
Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in
|
| 164 |
+
`google.ai.generativelanguage` for details on the other arguments.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
|
| 168 |
+
api_key: The API-Key to use when creating the default clients (each service uses
|
| 169 |
+
a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
|
| 170 |
+
If omitted, and the `GEMINI_API_KEY` or the `GOOGLE_API_KEY` environment variable
|
| 171 |
+
are set, they will be used in this order of priority.
|
| 172 |
+
default_metadata: Default (key, value) metadata pairs to send with every request.
|
| 173 |
+
when using `transport="rest"` these are sent as HTTP headers.
|
| 174 |
+
"""
|
| 175 |
+
if isinstance(client_options, dict):
|
| 176 |
+
client_options = client_options_lib.from_dict(client_options)
|
| 177 |
+
if client_options is None:
|
| 178 |
+
client_options = client_options_lib.ClientOptions()
|
| 179 |
+
client_options = cast(client_options_lib.ClientOptions, client_options)
|
| 180 |
+
had_api_key_value = getattr(client_options, "api_key", None)
|
| 181 |
+
|
| 182 |
+
if had_api_key_value:
|
| 183 |
+
if api_key is not None:
|
| 184 |
+
raise ValueError(
|
| 185 |
+
"Invalid configuration: Please set either `api_key` or `client_options['api_key']`, but not both."
|
| 186 |
+
)
|
| 187 |
+
else:
|
| 188 |
+
if api_key is None:
|
| 189 |
+
# If no key is provided explicitly, attempt to load one from the
|
| 190 |
+
# environment.
|
| 191 |
+
api_key = os.getenv("GEMINI_API_KEY")
|
| 192 |
+
|
| 193 |
+
if api_key is None:
|
| 194 |
+
# If the GEMINI_API_KEY doesn't exist, attempt to load the
|
| 195 |
+
# GOOGLE_API_KEY from the environment.
|
| 196 |
+
api_key = os.getenv("GOOGLE_API_KEY")
|
| 197 |
+
|
| 198 |
+
client_options.api_key = api_key
|
| 199 |
+
|
| 200 |
+
user_agent = f"{USER_AGENT}/{__version__}"
|
| 201 |
+
if client_info:
|
| 202 |
+
# Be respectful of any existing agent setting.
|
| 203 |
+
if client_info.user_agent:
|
| 204 |
+
client_info.user_agent += f" {user_agent}"
|
| 205 |
+
else:
|
| 206 |
+
client_info.user_agent = user_agent
|
| 207 |
+
else:
|
| 208 |
+
client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)
|
| 209 |
+
|
| 210 |
+
client_config = {
|
| 211 |
+
"credentials": credentials,
|
| 212 |
+
"transport": transport,
|
| 213 |
+
"client_options": client_options,
|
| 214 |
+
"client_info": client_info,
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
client_config = {key: value for key, value in client_config.items() if value is not None}
|
| 218 |
+
|
| 219 |
+
self.client_config = client_config
|
| 220 |
+
self.default_metadata = default_metadata
|
| 221 |
+
|
| 222 |
+
self.clients = {}
|
| 223 |
+
|
| 224 |
+
def make_client(self, name):
|
| 225 |
+
if name == "file":
|
| 226 |
+
cls = FileServiceClient
|
| 227 |
+
elif name == "file_async":
|
| 228 |
+
cls = FileServiceAsyncClient
|
| 229 |
+
elif name.endswith("_async"):
|
| 230 |
+
name = name.split("_")[0]
|
| 231 |
+
cls = getattr(glm, name.title() + "ServiceAsyncClient")
|
| 232 |
+
else:
|
| 233 |
+
cls = getattr(glm, name.title() + "ServiceClient")
|
| 234 |
+
|
| 235 |
+
# Attempt to configure using defaults.
|
| 236 |
+
if not self.client_config:
|
| 237 |
+
configure()
|
| 238 |
+
|
| 239 |
+
try:
|
| 240 |
+
with patch_colab_gce_credentials():
|
| 241 |
+
client = cls(**self.client_config)
|
| 242 |
+
except ga_exceptions.DefaultCredentialsError as e:
|
| 243 |
+
e.args = (
|
| 244 |
+
"\n No API_KEY or ADC found. Please either:\n"
|
| 245 |
+
" - Set the `GOOGLE_API_KEY` environment variable.\n"
|
| 246 |
+
" - Manually pass the key with `genai.configure(api_key=my_api_key)`.\n"
|
| 247 |
+
" - Or set up Application Default Credentials, see https://ai.google.dev/gemini-api/docs/oauth for more information.",
|
| 248 |
+
)
|
| 249 |
+
raise e
|
| 250 |
+
|
| 251 |
+
if not self.default_metadata:
|
| 252 |
+
return client
|
| 253 |
+
|
| 254 |
+
def keep(name, f):
|
| 255 |
+
if name.startswith("_"):
|
| 256 |
+
return False
|
| 257 |
+
|
| 258 |
+
if not callable(f):
|
| 259 |
+
return False
|
| 260 |
+
|
| 261 |
+
if "metadata" not in inspect.signature(f).parameters.keys():
|
| 262 |
+
return False
|
| 263 |
+
|
| 264 |
+
return True
|
| 265 |
+
|
| 266 |
+
def add_default_metadata_wrapper(f):
|
| 267 |
+
def call(*args, metadata=(), **kwargs):
|
| 268 |
+
metadata = list(metadata) + list(self.default_metadata)
|
| 269 |
+
return f(*args, **kwargs, metadata=metadata)
|
| 270 |
+
|
| 271 |
+
return call
|
| 272 |
+
|
| 273 |
+
for name, value in inspect.getmembers(cls):
|
| 274 |
+
if not keep(name, value):
|
| 275 |
+
continue
|
| 276 |
+
f = getattr(client, name)
|
| 277 |
+
f = add_default_metadata_wrapper(f)
|
| 278 |
+
setattr(client, name, f)
|
| 279 |
+
|
| 280 |
+
return client
|
| 281 |
+
|
| 282 |
+
def get_default_client(self, name):
|
| 283 |
+
name = name.lower()
|
| 284 |
+
if name == "operations":
|
| 285 |
+
return self.get_default_operations_client()
|
| 286 |
+
|
| 287 |
+
client = self.clients.get(name)
|
| 288 |
+
if client is None:
|
| 289 |
+
client = self.make_client(name)
|
| 290 |
+
self.clients[name] = client
|
| 291 |
+
return client
|
| 292 |
+
|
| 293 |
+
def get_default_operations_client(self) -> operations_v1.OperationsClient:
|
| 294 |
+
client = self.clients.get("operations", None)
|
| 295 |
+
if client is None:
|
| 296 |
+
model_client = self.get_default_client("Model")
|
| 297 |
+
client = model_client._transport.operations_client
|
| 298 |
+
self.clients["operations"] = client
|
| 299 |
+
return client
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def configure(
|
| 303 |
+
*,
|
| 304 |
+
api_key: str | None = None,
|
| 305 |
+
credentials: ga_credentials.Credentials | dict | None = None,
|
| 306 |
+
# The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'.
|
| 307 |
+
# Since the transport classes align with the client classes it wouldn't make
|
| 308 |
+
# sense to accept a `Transport` object here even though the client classes can.
|
| 309 |
+
# We could accept a dict since all the `Transport` classes take the same args,
|
| 310 |
+
# but that seems rare. Users that need it can just switch to the low level API.
|
| 311 |
+
transport: str | None = None,
|
| 312 |
+
client_options: client_options_lib.ClientOptions | dict | None = None,
|
| 313 |
+
client_info: gapic_v1.client_info.ClientInfo | None = None,
|
| 314 |
+
default_metadata: Sequence[tuple[str, str]] = (),
|
| 315 |
+
):
|
| 316 |
+
"""Captures default client configuration.
|
| 317 |
+
|
| 318 |
+
If no API key has been provided (either directly, or on `client_options`) and the
|
| 319 |
+
`GOOGLE_API_KEY` environment variable is set, it will be used as the API key.
|
| 320 |
+
|
| 321 |
+
Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in
|
| 322 |
+
`google.ai.generativelanguage` for details on the other arguments.
|
| 323 |
+
|
| 324 |
+
Args:
|
| 325 |
+
transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
|
| 326 |
+
api_key: The API-Key to use when creating the default clients (each service uses
|
| 327 |
+
a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
|
| 328 |
+
If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be
|
| 329 |
+
used.
|
| 330 |
+
default_metadata: Default (key, value) metadata pairs to send with every request.
|
| 331 |
+
when using `transport="rest"` these are sent as HTTP headers.
|
| 332 |
+
"""
|
| 333 |
+
return _client_manager.configure(
|
| 334 |
+
api_key=api_key,
|
| 335 |
+
credentials=credentials,
|
| 336 |
+
transport=transport,
|
| 337 |
+
client_options=client_options,
|
| 338 |
+
client_info=client_info,
|
| 339 |
+
default_metadata=default_metadata,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
_client_manager = _ClientManager()
|
| 344 |
+
_client_manager.configure()
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def get_default_cache_client() -> glm.CacheServiceClient:
|
| 348 |
+
return _client_manager.get_default_client("cache")
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def get_default_file_client() -> glm.FilesServiceClient:
|
| 352 |
+
return _client_manager.get_default_client("file")
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def get_default_file_async_client() -> glm.FilesServiceAsyncClient:
|
| 356 |
+
return _client_manager.get_default_client("file_async")
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def get_default_generative_client() -> glm.GenerativeServiceClient:
|
| 360 |
+
return _client_manager.get_default_client("generative")
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def get_default_generative_async_client() -> glm.GenerativeServiceAsyncClient:
|
| 364 |
+
return _client_manager.get_default_client("generative_async")
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def get_default_operations_client() -> operations_v1.OperationsClient:
|
| 368 |
+
return _client_manager.get_default_client("operations")
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def get_default_model_client() -> glm.ModelServiceAsyncClient:
|
| 372 |
+
return _client_manager.get_default_client("model")
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def get_default_retriever_client() -> glm.RetrieverClient:
|
| 376 |
+
return _client_manager.get_default_client("retriever")
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def get_default_retriever_async_client() -> glm.RetrieverAsyncClient:
|
| 380 |
+
return _client_manager.get_default_client("retriever_async")
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def get_default_permission_client() -> glm.PermissionServiceClient:
|
| 384 |
+
return _client_manager.get_default_client("permission")
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def get_default_permission_async_client() -> glm.PermissionServiceAsyncClient:
|
| 388 |
+
return _client_manager.get_default_client("permission_async")
|
.venv/lib/python3.11/site-packages/google/generativeai/discuss.py
ADDED
|
@@ -0,0 +1,590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import dataclasses
|
| 18 |
+
import sys
|
| 19 |
+
import textwrap
|
| 20 |
+
|
| 21 |
+
from typing import Any, Iterable, List, Optional, Union
|
| 22 |
+
|
| 23 |
+
import google.ai.generativelanguage as glm
|
| 24 |
+
|
| 25 |
+
from google.generativeai.client import get_default_discuss_client
|
| 26 |
+
from google.generativeai.client import get_default_discuss_async_client
|
| 27 |
+
from google.generativeai import string_utils
|
| 28 |
+
from google.generativeai.types import discuss_types
|
| 29 |
+
from google.generativeai.types import model_types
|
| 30 |
+
from google.generativeai.types import safety_types
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
|
| 34 |
+
"""Creates a `glm.Message` object from the provided content."""
|
| 35 |
+
if isinstance(content, glm.Message):
|
| 36 |
+
return content
|
| 37 |
+
if isinstance(content, str):
|
| 38 |
+
return glm.Message(content=content)
|
| 39 |
+
else:
|
| 40 |
+
return glm.Message(content)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _make_messages(
|
| 44 |
+
messages: discuss_types.MessagesOptions,
|
| 45 |
+
) -> List[glm.Message]:
|
| 46 |
+
"""
|
| 47 |
+
Creates a list of `glm.Message` objects from the provided messages.
|
| 48 |
+
|
| 49 |
+
This function takes a variety of message content inputs, such as strings, dictionaries,
|
| 50 |
+
or `glm.Message` objects, and creates a list of `glm.Message` objects. It ensures that
|
| 51 |
+
the authors of the messages alternate appropriately. If authors are not provided,
|
| 52 |
+
default authors are assigned based on their position in the list.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
messages: The messages to convert.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
A list of `glm.Message` objects with alternating authors.
|
| 59 |
+
"""
|
| 60 |
+
if isinstance(messages, (str, dict, glm.Message)):
|
| 61 |
+
messages = [_make_message(messages)]
|
| 62 |
+
else:
|
| 63 |
+
messages = [_make_message(message) for message in messages]
|
| 64 |
+
|
| 65 |
+
even_authors = set(msg.author for msg in messages[::2] if msg.author)
|
| 66 |
+
if not even_authors:
|
| 67 |
+
even_author = "0"
|
| 68 |
+
elif len(even_authors) == 1:
|
| 69 |
+
even_author = even_authors.pop()
|
| 70 |
+
else:
|
| 71 |
+
raise discuss_types.AuthorError("Authors are not strictly alternating")
|
| 72 |
+
|
| 73 |
+
odd_authors = set(msg.author for msg in messages[1::2] if msg.author)
|
| 74 |
+
if not odd_authors:
|
| 75 |
+
odd_author = "1"
|
| 76 |
+
elif len(odd_authors) == 1:
|
| 77 |
+
odd_author = odd_authors.pop()
|
| 78 |
+
else:
|
| 79 |
+
raise discuss_types.AuthorError("Authors are not strictly alternating")
|
| 80 |
+
|
| 81 |
+
if all(msg.author for msg in messages):
|
| 82 |
+
return messages
|
| 83 |
+
|
| 84 |
+
authors = [even_author, odd_author]
|
| 85 |
+
for i, msg in enumerate(messages):
|
| 86 |
+
msg.author = authors[i % 2]
|
| 87 |
+
|
| 88 |
+
return messages
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _make_example(item: discuss_types.ExampleOptions) -> glm.Example:
|
| 92 |
+
"""Creates a `glm.Example` object from the provided item."""
|
| 93 |
+
if isinstance(item, glm.Example):
|
| 94 |
+
return item
|
| 95 |
+
|
| 96 |
+
if isinstance(item, dict):
|
| 97 |
+
item = item.copy()
|
| 98 |
+
item["input"] = _make_message(item["input"])
|
| 99 |
+
item["output"] = _make_message(item["output"])
|
| 100 |
+
return glm.Example(item)
|
| 101 |
+
|
| 102 |
+
if isinstance(item, Iterable):
|
| 103 |
+
input, output = list(item)
|
| 104 |
+
return glm.Example(input=_make_message(input), output=_make_message(output))
|
| 105 |
+
|
| 106 |
+
# try anyway
|
| 107 |
+
return glm.Example(item)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _make_examples_from_flat(
|
| 111 |
+
examples: List[discuss_types.MessageOptions],
|
| 112 |
+
) -> List[glm.Example]:
|
| 113 |
+
"""
|
| 114 |
+
Creates a list of `glm.Example` objects from a list of message options.
|
| 115 |
+
|
| 116 |
+
This function takes a list of `discuss_types.MessageOptions` and pairs them into
|
| 117 |
+
`glm.Example` objects. The input examples must be in pairs to create valid examples.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
examples: The list of `discuss_types.MessageOptions`.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
A list of `glm.Example objects` created by pairing up the provided messages.
|
| 124 |
+
|
| 125 |
+
Raises:
|
| 126 |
+
ValueError: If the provided list of examples is not of even length.
|
| 127 |
+
"""
|
| 128 |
+
if len(examples) % 2 != 0:
|
| 129 |
+
raise ValueError(
|
| 130 |
+
textwrap.dedent(
|
| 131 |
+
f"""\
|
| 132 |
+
You must pass `Primer` objects, pairs of messages, or an *even* number of messages, got:
|
| 133 |
+
{len(examples)} messages"""
|
| 134 |
+
)
|
| 135 |
+
)
|
| 136 |
+
result = []
|
| 137 |
+
pair = []
|
| 138 |
+
for n, item in enumerate(examples):
|
| 139 |
+
msg = _make_message(item)
|
| 140 |
+
pair.append(msg)
|
| 141 |
+
if n % 2 == 0:
|
| 142 |
+
continue
|
| 143 |
+
primer = glm.Example(
|
| 144 |
+
input=pair[0],
|
| 145 |
+
output=pair[1],
|
| 146 |
+
)
|
| 147 |
+
result.append(primer)
|
| 148 |
+
pair = []
|
| 149 |
+
return result
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _make_examples(
|
| 153 |
+
examples: discuss_types.ExamplesOptions,
|
| 154 |
+
) -> List[glm.Example]:
|
| 155 |
+
"""
|
| 156 |
+
Creates a list of `glm.Example` objects from the provided examples.
|
| 157 |
+
|
| 158 |
+
This function takes various types of example content inputs and creates a list
|
| 159 |
+
of `glm.Example` objects. It handles the conversion of different input types and ensures
|
| 160 |
+
the appropriate structure for creating valid examples.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
examples: The examples to convert.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
A list of `glm.Example` objects created from the provided examples.
|
| 167 |
+
"""
|
| 168 |
+
if isinstance(examples, glm.Example):
|
| 169 |
+
return [examples]
|
| 170 |
+
|
| 171 |
+
if isinstance(examples, dict):
|
| 172 |
+
return [_make_example(examples)]
|
| 173 |
+
|
| 174 |
+
examples = list(examples)
|
| 175 |
+
|
| 176 |
+
if not examples:
|
| 177 |
+
return examples
|
| 178 |
+
|
| 179 |
+
first = examples[0]
|
| 180 |
+
|
| 181 |
+
if isinstance(first, dict):
|
| 182 |
+
if "content" in first:
|
| 183 |
+
# These are `Messages`
|
| 184 |
+
return _make_examples_from_flat(examples)
|
| 185 |
+
else:
|
| 186 |
+
if not ("input" in first and "output" in first):
|
| 187 |
+
raise TypeError(
|
| 188 |
+
"To create an `Example` from a dict you must supply both `input` and an `output` keys"
|
| 189 |
+
)
|
| 190 |
+
else:
|
| 191 |
+
if isinstance(first, discuss_types.MESSAGE_OPTIONS):
|
| 192 |
+
return _make_examples_from_flat(examples)
|
| 193 |
+
|
| 194 |
+
result = []
|
| 195 |
+
for item in examples:
|
| 196 |
+
result.append(_make_example(item))
|
| 197 |
+
return result
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _make_message_prompt_dict(
|
| 201 |
+
prompt: discuss_types.MessagePromptOptions = None,
|
| 202 |
+
*,
|
| 203 |
+
context: str | None = None,
|
| 204 |
+
examples: discuss_types.ExamplesOptions | None = None,
|
| 205 |
+
messages: discuss_types.MessagesOptions | None = None,
|
| 206 |
+
) -> glm.MessagePrompt:
|
| 207 |
+
"""
|
| 208 |
+
Creates a `glm.MessagePrompt` object from the provided prompt components.
|
| 209 |
+
|
| 210 |
+
This function constructs a `glm.MessagePrompt` object using the provided `context`, `examples`,
|
| 211 |
+
or `messages`. It ensures the proper structure and handling of the input components.
|
| 212 |
+
|
| 213 |
+
Either pass a `prompt` or it's component `context`, `examples`, `messages`.
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
prompt: The complete prompt components.
|
| 217 |
+
context: The context for the prompt.
|
| 218 |
+
examples: The examples for the prompt.
|
| 219 |
+
messages: The messages for the prompt.
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
A `glm.MessagePrompt` object created from the provided prompt components.
|
| 223 |
+
"""
|
| 224 |
+
if prompt is None:
|
| 225 |
+
prompt = dict(
|
| 226 |
+
context=context,
|
| 227 |
+
examples=examples,
|
| 228 |
+
messages=messages,
|
| 229 |
+
)
|
| 230 |
+
else:
|
| 231 |
+
flat_prompt = (context is not None) or (examples is not None) or (messages is not None)
|
| 232 |
+
if flat_prompt:
|
| 233 |
+
raise ValueError(
|
| 234 |
+
"You can't set `prompt`, and its fields `(context, examples, messages)`"
|
| 235 |
+
" at the same time"
|
| 236 |
+
)
|
| 237 |
+
if isinstance(prompt, glm.MessagePrompt):
|
| 238 |
+
return prompt
|
| 239 |
+
elif isinstance(prompt, dict): # Always check dict before Iterable.
|
| 240 |
+
pass
|
| 241 |
+
else:
|
| 242 |
+
prompt = {"messages": prompt}
|
| 243 |
+
|
| 244 |
+
keys = set(prompt.keys())
|
| 245 |
+
if not keys.issubset(discuss_types.MESSAGE_PROMPT_KEYS):
|
| 246 |
+
raise KeyError(
|
| 247 |
+
f"Found extra entries in the prompt dictionary: {keys - discuss_types.MESSAGE_PROMPT_KEYS}"
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
examples = prompt.get("examples", None)
|
| 251 |
+
if examples is not None:
|
| 252 |
+
prompt["examples"] = _make_examples(examples)
|
| 253 |
+
messages = prompt.get("messages", None)
|
| 254 |
+
if messages is not None:
|
| 255 |
+
prompt["messages"] = _make_messages(messages)
|
| 256 |
+
|
| 257 |
+
prompt = {k: v for k, v in prompt.items() if v is not None}
|
| 258 |
+
return prompt
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def _make_message_prompt(
|
| 262 |
+
prompt: discuss_types.MessagePromptOptions = None,
|
| 263 |
+
*,
|
| 264 |
+
context: str | None = None,
|
| 265 |
+
examples: discuss_types.ExamplesOptions | None = None,
|
| 266 |
+
messages: discuss_types.MessagesOptions | None = None,
|
| 267 |
+
) -> glm.MessagePrompt:
|
| 268 |
+
"""Creates a `glm.MessagePrompt` object from the provided prompt components."""
|
| 269 |
+
prompt = _make_message_prompt_dict(
|
| 270 |
+
prompt=prompt, context=context, examples=examples, messages=messages
|
| 271 |
+
)
|
| 272 |
+
return glm.MessagePrompt(prompt)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def _make_generate_message_request(
|
| 276 |
+
*,
|
| 277 |
+
model: model_types.AnyModelNameOptions | None,
|
| 278 |
+
context: str | None = None,
|
| 279 |
+
examples: discuss_types.ExamplesOptions | None = None,
|
| 280 |
+
messages: discuss_types.MessagesOptions | None = None,
|
| 281 |
+
temperature: float | None = None,
|
| 282 |
+
candidate_count: int | None = None,
|
| 283 |
+
top_p: float | None = None,
|
| 284 |
+
top_k: float | None = None,
|
| 285 |
+
prompt: discuss_types.MessagePromptOptions | None = None,
|
| 286 |
+
) -> glm.GenerateMessageRequest:
|
| 287 |
+
"""Creates a `glm.GenerateMessageRequest` object for generating messages."""
|
| 288 |
+
model = model_types.make_model_name(model)
|
| 289 |
+
|
| 290 |
+
prompt = _make_message_prompt(
|
| 291 |
+
prompt=prompt, context=context, examples=examples, messages=messages
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
return glm.GenerateMessageRequest(
|
| 295 |
+
model=model,
|
| 296 |
+
prompt=prompt,
|
| 297 |
+
temperature=temperature,
|
| 298 |
+
top_p=top_p,
|
| 299 |
+
top_k=top_k,
|
| 300 |
+
candidate_count=candidate_count,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
DEFAULT_DISCUSS_MODEL = "models/chat-bison-001"
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def chat(
|
| 308 |
+
*,
|
| 309 |
+
model: model_types.AnyModelNameOptions | None = "models/chat-bison-001",
|
| 310 |
+
context: str | None = None,
|
| 311 |
+
examples: discuss_types.ExamplesOptions | None = None,
|
| 312 |
+
messages: discuss_types.MessagesOptions | None = None,
|
| 313 |
+
temperature: float | None = None,
|
| 314 |
+
candidate_count: int | None = None,
|
| 315 |
+
top_p: float | None = None,
|
| 316 |
+
top_k: float | None = None,
|
| 317 |
+
prompt: discuss_types.MessagePromptOptions | None = None,
|
| 318 |
+
client: glm.DiscussServiceClient | None = None,
|
| 319 |
+
request_options: dict[str, Any] | None = None,
|
| 320 |
+
) -> discuss_types.ChatResponse:
|
| 321 |
+
"""Calls the API and returns a `types.ChatResponse` containing the response.
|
| 322 |
+
|
| 323 |
+
Args:
|
| 324 |
+
model: Which model to call, as a string or a `types.Model`.
|
| 325 |
+
context: Text that should be provided to the model first, to ground the response.
|
| 326 |
+
|
| 327 |
+
If not empty, this `context` will be given to the model first before the
|
| 328 |
+
`examples` and `messages`.
|
| 329 |
+
|
| 330 |
+
This field can be a description of your prompt to the model to help provide
|
| 331 |
+
context and guide the responses.
|
| 332 |
+
|
| 333 |
+
Examples:
|
| 334 |
+
|
| 335 |
+
* "Translate the phrase from English to French."
|
| 336 |
+
* "Given a statement, classify the sentiment as happy, sad or neutral."
|
| 337 |
+
|
| 338 |
+
Anything included in this field will take precedence over history in `messages`
|
| 339 |
+
if the total input size exceeds the model's `Model.input_token_limit`.
|
| 340 |
+
examples: Examples of what the model should generate.
|
| 341 |
+
|
| 342 |
+
This includes both the user input and the response that the model should
|
| 343 |
+
emulate.
|
| 344 |
+
|
| 345 |
+
These `examples` are treated identically to conversation messages except
|
| 346 |
+
that they take precedence over the history in `messages`:
|
| 347 |
+
If the total input size exceeds the model's `input_token_limit` the input
|
| 348 |
+
will be truncated. Items will be dropped from `messages` before `examples`
|
| 349 |
+
messages: A snapshot of the conversation history sorted chronologically.
|
| 350 |
+
|
| 351 |
+
Turns alternate between two authors.
|
| 352 |
+
|
| 353 |
+
If the total input size exceeds the model's `input_token_limit` the input
|
| 354 |
+
will be truncated: The oldest items will be dropped from `messages`.
|
| 355 |
+
temperature: Controls the randomness of the output. Must be positive.
|
| 356 |
+
|
| 357 |
+
Typical values are in the range: `[0.0,1.0]`. Higher values produce a
|
| 358 |
+
more random and varied response. A temperature of zero will be deterministic.
|
| 359 |
+
candidate_count: The **maximum** number of generated response messages to return.
|
| 360 |
+
|
| 361 |
+
This value must be between `[1, 8]`, inclusive. If unset, this
|
| 362 |
+
will default to `1`.
|
| 363 |
+
|
| 364 |
+
Note: Only unique candidates are returned. Higher temperatures are more
|
| 365 |
+
likely to produce unique candidates. Setting `temperature=0.0` will always
|
| 366 |
+
return 1 candidate regardless of the `candidate_count`.
|
| 367 |
+
top_k: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and
|
| 368 |
+
top-k sampling.
|
| 369 |
+
|
| 370 |
+
`top_k` sets the maximum number of tokens to sample from on each step.
|
| 371 |
+
top_p: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and
|
| 372 |
+
top-k sampling.
|
| 373 |
+
|
| 374 |
+
`top_p` configures the nucleus sampling. It sets the maximum cumulative
|
| 375 |
+
probability of tokens to sample from.
|
| 376 |
+
|
| 377 |
+
For example, if the sorted probabilities are
|
| 378 |
+
`[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample
|
| 379 |
+
as `[0.625, 0.25, 0.125, 0, 0, 0]`.
|
| 380 |
+
|
| 381 |
+
Typical values are in the `[0.9, 1.0]` range.
|
| 382 |
+
prompt: You may pass a `types.MessagePromptOptions` **instead** of a
|
| 383 |
+
setting `context`/`examples`/`messages`, but not both.
|
| 384 |
+
client: If you're not relying on the default client, you pass a
|
| 385 |
+
`glm.DiscussServiceClient` instead.
|
| 386 |
+
request_options: Options for the request.
|
| 387 |
+
|
| 388 |
+
Returns:
|
| 389 |
+
A `types.ChatResponse` containing the model's reply.
|
| 390 |
+
"""
|
| 391 |
+
request = _make_generate_message_request(
|
| 392 |
+
model=model,
|
| 393 |
+
context=context,
|
| 394 |
+
examples=examples,
|
| 395 |
+
messages=messages,
|
| 396 |
+
temperature=temperature,
|
| 397 |
+
candidate_count=candidate_count,
|
| 398 |
+
top_p=top_p,
|
| 399 |
+
top_k=top_k,
|
| 400 |
+
prompt=prompt,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
return _generate_response(client=client, request=request, request_options=request_options)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
@string_utils.set_doc(chat.__doc__)
|
| 407 |
+
async def chat_async(
|
| 408 |
+
*,
|
| 409 |
+
model: model_types.AnyModelNameOptions | None = "models/chat-bison-001",
|
| 410 |
+
context: str | None = None,
|
| 411 |
+
examples: discuss_types.ExamplesOptions | None = None,
|
| 412 |
+
messages: discuss_types.MessagesOptions | None = None,
|
| 413 |
+
temperature: float | None = None,
|
| 414 |
+
candidate_count: int | None = None,
|
| 415 |
+
top_p: float | None = None,
|
| 416 |
+
top_k: float | None = None,
|
| 417 |
+
prompt: discuss_types.MessagePromptOptions | None = None,
|
| 418 |
+
client: glm.DiscussServiceAsyncClient | None = None,
|
| 419 |
+
request_options: dict[str, Any] | None = None,
|
| 420 |
+
) -> discuss_types.ChatResponse:
|
| 421 |
+
request = _make_generate_message_request(
|
| 422 |
+
model=model,
|
| 423 |
+
context=context,
|
| 424 |
+
examples=examples,
|
| 425 |
+
messages=messages,
|
| 426 |
+
temperature=temperature,
|
| 427 |
+
candidate_count=candidate_count,
|
| 428 |
+
top_p=top_p,
|
| 429 |
+
top_k=top_k,
|
| 430 |
+
prompt=prompt,
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
return await _generate_response_async(
|
| 434 |
+
client=client, request=request, request_options=request_options
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
if (sys.version_info.major, sys.version_info.minor) >= (3, 10):
|
| 439 |
+
DATACLASS_KWARGS = {"kw_only": True}
|
| 440 |
+
else:
|
| 441 |
+
DATACLASS_KWARGS = {}
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
@string_utils.prettyprint
|
| 445 |
+
@string_utils.set_doc(discuss_types.ChatResponse.__doc__)
|
| 446 |
+
@dataclasses.dataclass(**DATACLASS_KWARGS, init=False)
|
| 447 |
+
class ChatResponse(discuss_types.ChatResponse):
|
| 448 |
+
_client: glm.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False)
|
| 449 |
+
|
| 450 |
+
def __init__(self, **kwargs):
|
| 451 |
+
for key, value in kwargs.items():
|
| 452 |
+
setattr(self, key, value)
|
| 453 |
+
|
| 454 |
+
@property
|
| 455 |
+
@string_utils.set_doc(discuss_types.ChatResponse.last.__doc__)
|
| 456 |
+
def last(self) -> str | None:
|
| 457 |
+
if self.messages[-1]:
|
| 458 |
+
return self.messages[-1]["content"]
|
| 459 |
+
else:
|
| 460 |
+
return None
|
| 461 |
+
|
| 462 |
+
@last.setter
|
| 463 |
+
def last(self, message: discuss_types.MessageOptions):
|
| 464 |
+
message = _make_message(message)
|
| 465 |
+
message = type(message).to_dict(message)
|
| 466 |
+
self.messages[-1] = message
|
| 467 |
+
|
| 468 |
+
@string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
|
| 469 |
+
def reply(
|
| 470 |
+
self,
|
| 471 |
+
message: discuss_types.MessageOptions,
|
| 472 |
+
request_options: dict[str, Any] | None = None,
|
| 473 |
+
) -> discuss_types.ChatResponse:
|
| 474 |
+
if isinstance(self._client, glm.DiscussServiceAsyncClient):
|
| 475 |
+
raise TypeError(f"reply can't be called on an async client, use reply_async instead.")
|
| 476 |
+
if self.last is None:
|
| 477 |
+
raise ValueError(
|
| 478 |
+
"The last response from the model did not return any candidates.\n"
|
| 479 |
+
"Check the `.filters` attribute to see why the responses were filtered:\n"
|
| 480 |
+
f"{self.filters}"
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
request = self.to_dict()
|
| 484 |
+
request.pop("candidates")
|
| 485 |
+
request.pop("filters", None)
|
| 486 |
+
request["messages"] = list(request["messages"])
|
| 487 |
+
request["messages"].append(_make_message(message))
|
| 488 |
+
request = _make_generate_message_request(**request)
|
| 489 |
+
return _generate_response(
|
| 490 |
+
request=request, client=self._client, request_options=request_options
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
@string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
|
| 494 |
+
async def reply_async(
|
| 495 |
+
self, message: discuss_types.MessageOptions
|
| 496 |
+
) -> discuss_types.ChatResponse:
|
| 497 |
+
if isinstance(self._client, glm.DiscussServiceClient):
|
| 498 |
+
raise TypeError(
|
| 499 |
+
f"reply_async can't be called on a non-async client, use reply instead."
|
| 500 |
+
)
|
| 501 |
+
request = self.to_dict()
|
| 502 |
+
request.pop("candidates")
|
| 503 |
+
request.pop("filters", None)
|
| 504 |
+
request["messages"] = list(request["messages"])
|
| 505 |
+
request["messages"].append(_make_message(message))
|
| 506 |
+
request = _make_generate_message_request(**request)
|
| 507 |
+
return await _generate_response_async(request=request, client=self._client)
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def _build_chat_response(
|
| 511 |
+
request: glm.GenerateMessageRequest,
|
| 512 |
+
response: glm.GenerateMessageResponse,
|
| 513 |
+
client: glm.DiscussServiceClient | glm.DiscussServiceAsyncClient,
|
| 514 |
+
) -> ChatResponse:
|
| 515 |
+
request = type(request).to_dict(request)
|
| 516 |
+
prompt = request.pop("prompt")
|
| 517 |
+
request["examples"] = prompt["examples"]
|
| 518 |
+
request["context"] = prompt["context"]
|
| 519 |
+
request["messages"] = prompt["messages"]
|
| 520 |
+
|
| 521 |
+
response = type(response).to_dict(response)
|
| 522 |
+
response.pop("messages")
|
| 523 |
+
|
| 524 |
+
response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
|
| 525 |
+
|
| 526 |
+
if response["candidates"]:
|
| 527 |
+
last = response["candidates"][0]
|
| 528 |
+
else:
|
| 529 |
+
last = None
|
| 530 |
+
request["messages"].append(last)
|
| 531 |
+
request.setdefault("temperature", None)
|
| 532 |
+
request.setdefault("candidate_count", None)
|
| 533 |
+
|
| 534 |
+
return ChatResponse(_client=client, **response, **request) # pytype: disable=missing-parameter
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def _generate_response(
|
| 538 |
+
request: glm.GenerateMessageRequest,
|
| 539 |
+
client: glm.DiscussServiceClient | None = None,
|
| 540 |
+
request_options: dict[str, Any] | None = None,
|
| 541 |
+
) -> ChatResponse:
|
| 542 |
+
if request_options is None:
|
| 543 |
+
request_options = {}
|
| 544 |
+
|
| 545 |
+
if client is None:
|
| 546 |
+
client = get_default_discuss_client()
|
| 547 |
+
|
| 548 |
+
response = client.generate_message(request, **request_options)
|
| 549 |
+
|
| 550 |
+
return _build_chat_response(request, response, client)
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
async def _generate_response_async(
|
| 554 |
+
request: glm.GenerateMessageRequest,
|
| 555 |
+
client: glm.DiscussServiceAsyncClient | None = None,
|
| 556 |
+
request_options: dict[str, Any] | None = None,
|
| 557 |
+
) -> ChatResponse:
|
| 558 |
+
if request_options is None:
|
| 559 |
+
request_options = {}
|
| 560 |
+
|
| 561 |
+
if client is None:
|
| 562 |
+
client = get_default_discuss_async_client()
|
| 563 |
+
|
| 564 |
+
response = await client.generate_message(request, **request_options)
|
| 565 |
+
|
| 566 |
+
return _build_chat_response(request, response, client)
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def count_message_tokens(
|
| 570 |
+
*,
|
| 571 |
+
prompt: discuss_types.MessagePromptOptions = None,
|
| 572 |
+
context: str | None = None,
|
| 573 |
+
examples: discuss_types.ExamplesOptions | None = None,
|
| 574 |
+
messages: discuss_types.MessagesOptions | None = None,
|
| 575 |
+
model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL,
|
| 576 |
+
client: glm.DiscussServiceAsyncClient | None = None,
|
| 577 |
+
request_options: dict[str, Any] | None = None,
|
| 578 |
+
) -> discuss_types.TokenCount:
|
| 579 |
+
model = model_types.make_model_name(model)
|
| 580 |
+
prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)
|
| 581 |
+
|
| 582 |
+
if request_options is None:
|
| 583 |
+
request_options = {}
|
| 584 |
+
|
| 585 |
+
if client is None:
|
| 586 |
+
client = get_default_discuss_client()
|
| 587 |
+
|
| 588 |
+
result = client.count_message_tokens(model=model, prompt=prompt, **request_options)
|
| 589 |
+
|
| 590 |
+
return type(result).to_dict(result)
|
.venv/lib/python3.11/site-packages/google/generativeai/embedding.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import itertools
|
| 18 |
+
from typing import Any, Iterable, overload, TypeVar, Union, Mapping
|
| 19 |
+
|
| 20 |
+
import google.ai.generativelanguage as glm
|
| 21 |
+
from google.generativeai import protos
|
| 22 |
+
|
| 23 |
+
from google.generativeai.client import get_default_generative_client
|
| 24 |
+
from google.generativeai.client import get_default_generative_async_client
|
| 25 |
+
|
| 26 |
+
from google.generativeai.types import helper_types
|
| 27 |
+
from google.generativeai.types import model_types
|
| 28 |
+
from google.generativeai.types import text_types
|
| 29 |
+
from google.generativeai.types import content_types
|
| 30 |
+
|
| 31 |
+
DEFAULT_EMB_MODEL = "models/embedding-001"
|
| 32 |
+
EMBEDDING_MAX_BATCH_SIZE = 100
|
| 33 |
+
|
| 34 |
+
EmbeddingTaskType = protos.TaskType
|
| 35 |
+
|
| 36 |
+
EmbeddingTaskTypeOptions = Union[int, str, EmbeddingTaskType]
|
| 37 |
+
|
| 38 |
+
_EMBEDDING_TASK_TYPE: dict[EmbeddingTaskTypeOptions, EmbeddingTaskType] = {
|
| 39 |
+
EmbeddingTaskType.TASK_TYPE_UNSPECIFIED: EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
|
| 40 |
+
0: EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
|
| 41 |
+
"task_type_unspecified": EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
|
| 42 |
+
"unspecified": EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
|
| 43 |
+
EmbeddingTaskType.RETRIEVAL_QUERY: EmbeddingTaskType.RETRIEVAL_QUERY,
|
| 44 |
+
1: EmbeddingTaskType.RETRIEVAL_QUERY,
|
| 45 |
+
"retrieval_query": EmbeddingTaskType.RETRIEVAL_QUERY,
|
| 46 |
+
"query": EmbeddingTaskType.RETRIEVAL_QUERY,
|
| 47 |
+
EmbeddingTaskType.RETRIEVAL_DOCUMENT: EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
| 48 |
+
2: EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
| 49 |
+
"retrieval_document": EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
| 50 |
+
"document": EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
| 51 |
+
EmbeddingTaskType.SEMANTIC_SIMILARITY: EmbeddingTaskType.SEMANTIC_SIMILARITY,
|
| 52 |
+
3: EmbeddingTaskType.SEMANTIC_SIMILARITY,
|
| 53 |
+
"semantic_similarity": EmbeddingTaskType.SEMANTIC_SIMILARITY,
|
| 54 |
+
"similarity": EmbeddingTaskType.SEMANTIC_SIMILARITY,
|
| 55 |
+
EmbeddingTaskType.CLASSIFICATION: EmbeddingTaskType.CLASSIFICATION,
|
| 56 |
+
4: EmbeddingTaskType.CLASSIFICATION,
|
| 57 |
+
"classification": EmbeddingTaskType.CLASSIFICATION,
|
| 58 |
+
EmbeddingTaskType.CLUSTERING: EmbeddingTaskType.CLUSTERING,
|
| 59 |
+
5: EmbeddingTaskType.CLUSTERING,
|
| 60 |
+
"clustering": EmbeddingTaskType.CLUSTERING,
|
| 61 |
+
6: EmbeddingTaskType.QUESTION_ANSWERING,
|
| 62 |
+
"question_answering": EmbeddingTaskType.QUESTION_ANSWERING,
|
| 63 |
+
"qa": EmbeddingTaskType.QUESTION_ANSWERING,
|
| 64 |
+
EmbeddingTaskType.QUESTION_ANSWERING: EmbeddingTaskType.QUESTION_ANSWERING,
|
| 65 |
+
7: EmbeddingTaskType.FACT_VERIFICATION,
|
| 66 |
+
"fact_verification": EmbeddingTaskType.FACT_VERIFICATION,
|
| 67 |
+
"verification": EmbeddingTaskType.FACT_VERIFICATION,
|
| 68 |
+
EmbeddingTaskType.FACT_VERIFICATION: EmbeddingTaskType.FACT_VERIFICATION,
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def to_task_type(x: EmbeddingTaskTypeOptions) -> EmbeddingTaskType:
|
| 73 |
+
if isinstance(x, str):
|
| 74 |
+
x = x.lower()
|
| 75 |
+
return _EMBEDDING_TASK_TYPE[x]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
try:
|
| 79 |
+
# python 3.12+
|
| 80 |
+
_batched = itertools.batched # type: ignore
|
| 81 |
+
except AttributeError:
|
| 82 |
+
T = TypeVar("T")
|
| 83 |
+
|
| 84 |
+
def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
|
| 85 |
+
if n < 1:
|
| 86 |
+
raise ValueError(
|
| 87 |
+
f"Invalid input: The batch size 'n' must be a positive integer. You entered: {n}. Please enter a number greater than 0."
|
| 88 |
+
)
|
| 89 |
+
batch = []
|
| 90 |
+
for item in iterable:
|
| 91 |
+
batch.append(item)
|
| 92 |
+
if len(batch) == n:
|
| 93 |
+
yield batch
|
| 94 |
+
batch = []
|
| 95 |
+
|
| 96 |
+
if batch:
|
| 97 |
+
yield batch
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@overload
|
| 101 |
+
def embed_content(
|
| 102 |
+
model: model_types.BaseModelNameOptions,
|
| 103 |
+
content: content_types.ContentType,
|
| 104 |
+
task_type: EmbeddingTaskTypeOptions | None = None,
|
| 105 |
+
title: str | None = None,
|
| 106 |
+
output_dimensionality: int | None = None,
|
| 107 |
+
client: glm.GenerativeServiceClient | None = None,
|
| 108 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 109 |
+
) -> text_types.EmbeddingDict: ...
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@overload
|
| 113 |
+
def embed_content(
|
| 114 |
+
model: model_types.BaseModelNameOptions,
|
| 115 |
+
content: Iterable[content_types.ContentType],
|
| 116 |
+
task_type: EmbeddingTaskTypeOptions | None = None,
|
| 117 |
+
title: str | None = None,
|
| 118 |
+
output_dimensionality: int | None = None,
|
| 119 |
+
client: glm.GenerativeServiceClient | None = None,
|
| 120 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 121 |
+
) -> text_types.BatchEmbeddingDict: ...
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def embed_content(
|
| 125 |
+
model: model_types.BaseModelNameOptions,
|
| 126 |
+
content: content_types.ContentType | Iterable[content_types.ContentType],
|
| 127 |
+
task_type: EmbeddingTaskTypeOptions | None = None,
|
| 128 |
+
title: str | None = None,
|
| 129 |
+
output_dimensionality: int | None = None,
|
| 130 |
+
client: glm.GenerativeServiceClient = None,
|
| 131 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 132 |
+
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
|
| 133 |
+
"""Calls the API to create embeddings for content passed in.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
model:
|
| 137 |
+
Which [model](https://ai.google.dev/models/gemini#embedding) to
|
| 138 |
+
call, as a string or a `types.Model`.
|
| 139 |
+
|
| 140 |
+
content:
|
| 141 |
+
Content to embed.
|
| 142 |
+
|
| 143 |
+
task_type:
|
| 144 |
+
Optional task type for which the embeddings will be used. Can only
|
| 145 |
+
be set for `models/embedding-001`.
|
| 146 |
+
|
| 147 |
+
title:
|
| 148 |
+
An optional title for the text. Only applicable when task_type is
|
| 149 |
+
`RETRIEVAL_DOCUMENT`.
|
| 150 |
+
|
| 151 |
+
output_dimensionality:
|
| 152 |
+
Optional reduced dimensionality for the output embeddings. If set,
|
| 153 |
+
excessive values from the output embeddings will be truncated from
|
| 154 |
+
the end.
|
| 155 |
+
|
| 156 |
+
request_options:
|
| 157 |
+
Options for the request.
|
| 158 |
+
|
| 159 |
+
Return:
|
| 160 |
+
Dictionary containing the embedding (list of float values) for the
|
| 161 |
+
input content.
|
| 162 |
+
"""
|
| 163 |
+
model = model_types.make_model_name(model)
|
| 164 |
+
|
| 165 |
+
if request_options is None:
|
| 166 |
+
request_options = {}
|
| 167 |
+
|
| 168 |
+
if client is None:
|
| 169 |
+
client = get_default_generative_client()
|
| 170 |
+
|
| 171 |
+
if title and to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT:
|
| 172 |
+
raise ValueError(
|
| 173 |
+
f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}."
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
if output_dimensionality and output_dimensionality < 0:
|
| 177 |
+
raise ValueError(
|
| 178 |
+
f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}."
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
if task_type:
|
| 182 |
+
task_type = to_task_type(task_type)
|
| 183 |
+
|
| 184 |
+
if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)):
|
| 185 |
+
result = {"embedding": []}
|
| 186 |
+
requests = (
|
| 187 |
+
protos.EmbedContentRequest(
|
| 188 |
+
model=model,
|
| 189 |
+
content=content_types.to_content(c),
|
| 190 |
+
task_type=task_type,
|
| 191 |
+
title=title,
|
| 192 |
+
output_dimensionality=output_dimensionality,
|
| 193 |
+
)
|
| 194 |
+
for c in content
|
| 195 |
+
)
|
| 196 |
+
for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE):
|
| 197 |
+
embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch)
|
| 198 |
+
embedding_response = client.batch_embed_contents(
|
| 199 |
+
embedding_request,
|
| 200 |
+
**request_options,
|
| 201 |
+
)
|
| 202 |
+
embedding_dict = type(embedding_response).to_dict(embedding_response)
|
| 203 |
+
result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"])
|
| 204 |
+
return result
|
| 205 |
+
else:
|
| 206 |
+
embedding_request = protos.EmbedContentRequest(
|
| 207 |
+
model=model,
|
| 208 |
+
content=content_types.to_content(content),
|
| 209 |
+
task_type=task_type,
|
| 210 |
+
title=title,
|
| 211 |
+
output_dimensionality=output_dimensionality,
|
| 212 |
+
)
|
| 213 |
+
embedding_response = client.embed_content(
|
| 214 |
+
embedding_request,
|
| 215 |
+
**request_options,
|
| 216 |
+
)
|
| 217 |
+
embedding_dict = type(embedding_response).to_dict(embedding_response)
|
| 218 |
+
embedding_dict["embedding"] = embedding_dict["embedding"]["values"]
|
| 219 |
+
return embedding_dict
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@overload
|
| 223 |
+
async def embed_content_async(
|
| 224 |
+
model: model_types.BaseModelNameOptions,
|
| 225 |
+
content: content_types.ContentType,
|
| 226 |
+
task_type: EmbeddingTaskTypeOptions | None = None,
|
| 227 |
+
title: str | None = None,
|
| 228 |
+
output_dimensionality: int | None = None,
|
| 229 |
+
client: glm.GenerativeServiceAsyncClient | None = None,
|
| 230 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 231 |
+
) -> text_types.EmbeddingDict: ...
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@overload
|
| 235 |
+
async def embed_content_async(
|
| 236 |
+
model: model_types.BaseModelNameOptions,
|
| 237 |
+
content: Iterable[content_types.ContentType],
|
| 238 |
+
task_type: EmbeddingTaskTypeOptions | None = None,
|
| 239 |
+
title: str | None = None,
|
| 240 |
+
output_dimensionality: int | None = None,
|
| 241 |
+
client: glm.GenerativeServiceAsyncClient | None = None,
|
| 242 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 243 |
+
) -> text_types.BatchEmbeddingDict: ...
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
async def embed_content_async(
|
| 247 |
+
model: model_types.BaseModelNameOptions,
|
| 248 |
+
content: content_types.ContentType | Iterable[content_types.ContentType],
|
| 249 |
+
task_type: EmbeddingTaskTypeOptions | None = None,
|
| 250 |
+
title: str | None = None,
|
| 251 |
+
output_dimensionality: int | None = None,
|
| 252 |
+
client: glm.GenerativeServiceAsyncClient = None,
|
| 253 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 254 |
+
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
|
| 255 |
+
"""Calls the API to create async embeddings for content passed in."""
|
| 256 |
+
|
| 257 |
+
model = model_types.make_model_name(model)
|
| 258 |
+
|
| 259 |
+
if request_options is None:
|
| 260 |
+
request_options = {}
|
| 261 |
+
|
| 262 |
+
if client is None:
|
| 263 |
+
client = get_default_generative_async_client()
|
| 264 |
+
|
| 265 |
+
if title and to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT:
|
| 266 |
+
raise ValueError(
|
| 267 |
+
f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}."
|
| 268 |
+
)
|
| 269 |
+
if output_dimensionality and output_dimensionality < 0:
|
| 270 |
+
raise ValueError(
|
| 271 |
+
f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}."
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
if task_type:
|
| 275 |
+
task_type = to_task_type(task_type)
|
| 276 |
+
|
| 277 |
+
if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)):
|
| 278 |
+
result = {"embedding": []}
|
| 279 |
+
requests = (
|
| 280 |
+
protos.EmbedContentRequest(
|
| 281 |
+
model=model,
|
| 282 |
+
content=content_types.to_content(c),
|
| 283 |
+
task_type=task_type,
|
| 284 |
+
title=title,
|
| 285 |
+
output_dimensionality=output_dimensionality,
|
| 286 |
+
)
|
| 287 |
+
for c in content
|
| 288 |
+
)
|
| 289 |
+
for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE):
|
| 290 |
+
embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch)
|
| 291 |
+
embedding_response = await client.batch_embed_contents(
|
| 292 |
+
embedding_request,
|
| 293 |
+
**request_options,
|
| 294 |
+
)
|
| 295 |
+
embedding_dict = type(embedding_response).to_dict(embedding_response)
|
| 296 |
+
result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"])
|
| 297 |
+
return result
|
| 298 |
+
else:
|
| 299 |
+
embedding_request = protos.EmbedContentRequest(
|
| 300 |
+
model=model,
|
| 301 |
+
content=content_types.to_content(content),
|
| 302 |
+
task_type=task_type,
|
| 303 |
+
title=title,
|
| 304 |
+
output_dimensionality=output_dimensionality,
|
| 305 |
+
)
|
| 306 |
+
embedding_response = await client.embed_content(
|
| 307 |
+
embedding_request,
|
| 308 |
+
**request_options,
|
| 309 |
+
)
|
| 310 |
+
embedding_dict = type(embedding_response).to_dict(embedding_response)
|
| 311 |
+
embedding_dict["embedding"] = embedding_dict["embedding"]["values"]
|
| 312 |
+
return embedding_dict
|
.venv/lib/python3.11/site-packages/google/generativeai/files.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import pathlib
|
| 19 |
+
import mimetypes
|
| 20 |
+
from typing import Iterable
|
| 21 |
+
import logging
|
| 22 |
+
from google.generativeai import protos
|
| 23 |
+
from itertools import islice
|
| 24 |
+
from io import IOBase
|
| 25 |
+
|
| 26 |
+
from google.generativeai.types import file_types
|
| 27 |
+
|
| 28 |
+
from google.generativeai.client import get_default_file_client
|
| 29 |
+
|
| 30 |
+
__all__ = ["upload_file", "get_file", "list_files", "delete_file"]
|
| 31 |
+
|
| 32 |
+
mimetypes.add_type("image/webp", ".webp")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def upload_file(
|
| 36 |
+
path: str | pathlib.Path | os.PathLike | IOBase,
|
| 37 |
+
*,
|
| 38 |
+
mime_type: str | None = None,
|
| 39 |
+
name: str | None = None,
|
| 40 |
+
display_name: str | None = None,
|
| 41 |
+
resumable: bool = True,
|
| 42 |
+
) -> file_types.File:
|
| 43 |
+
"""Calls the API to upload a file using a supported file service.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
path: The path to the file or a file-like object (e.g., BytesIO) to be uploaded.
|
| 47 |
+
mime_type: The MIME type of the file. If not provided, it will be
|
| 48 |
+
inferred from the file extension.
|
| 49 |
+
name: The name of the file in the destination (e.g., 'files/sample-image').
|
| 50 |
+
If not provided, a system generated ID will be created.
|
| 51 |
+
display_name: Optional display name of the file.
|
| 52 |
+
resumable: Whether to use the resumable upload protocol. By default, this is enabled.
|
| 53 |
+
See details at
|
| 54 |
+
https://googleapis.github.io/google-api-python-client/docs/epy/googleapiclient.http.MediaFileUpload-class.html#resumable
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
file_types.File: The response of the uploaded file.
|
| 58 |
+
"""
|
| 59 |
+
client = get_default_file_client()
|
| 60 |
+
|
| 61 |
+
if isinstance(path, IOBase):
|
| 62 |
+
if mime_type is None:
|
| 63 |
+
raise ValueError(
|
| 64 |
+
"Unknown mime type: When passing a file like object to `path` (instead of a\n"
|
| 65 |
+
" path-like object) you must set the `mime_type` argument"
|
| 66 |
+
)
|
| 67 |
+
else:
|
| 68 |
+
path = pathlib.Path(os.fspath(path))
|
| 69 |
+
|
| 70 |
+
if display_name is None:
|
| 71 |
+
display_name = path.name
|
| 72 |
+
|
| 73 |
+
if mime_type is None:
|
| 74 |
+
mime_type, _ = mimetypes.guess_type(path)
|
| 75 |
+
|
| 76 |
+
if mime_type is None:
|
| 77 |
+
raise ValueError(
|
| 78 |
+
"Unknown mime type: Could not determine the mimetype for your file\n"
|
| 79 |
+
" please set the `mime_type` argument"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
if name is not None and "/" not in name:
|
| 83 |
+
name = f"files/{name}"
|
| 84 |
+
|
| 85 |
+
response = client.create_file(
|
| 86 |
+
path=path, mime_type=mime_type, name=name, display_name=display_name, resumable=resumable
|
| 87 |
+
)
|
| 88 |
+
return file_types.File(response)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def list_files(page_size=100) -> Iterable[file_types.File]:
|
| 92 |
+
"""Calls the API to list files using a supported file service."""
|
| 93 |
+
client = get_default_file_client()
|
| 94 |
+
|
| 95 |
+
response = client.list_files(protos.ListFilesRequest(page_size=page_size))
|
| 96 |
+
for proto in response:
|
| 97 |
+
yield file_types.File(proto)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def get_file(name: str) -> file_types.File:
|
| 101 |
+
"""Calls the API to retrieve a specified file using a supported file service."""
|
| 102 |
+
if "/" not in name:
|
| 103 |
+
name = f"files/{name}"
|
| 104 |
+
client = get_default_file_client()
|
| 105 |
+
return file_types.File(client.get_file(name=name))
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def delete_file(name: str | file_types.File | protos.File):
|
| 109 |
+
"""Calls the API to permanently delete a specified file using a supported file service."""
|
| 110 |
+
if isinstance(name, (file_types.File, protos.File)):
|
| 111 |
+
name = name.name
|
| 112 |
+
elif "/" not in name:
|
| 113 |
+
name = f"files/{name}"
|
| 114 |
+
request = protos.DeleteFileRequest(name=name)
|
| 115 |
+
client = get_default_file_client()
|
| 116 |
+
client.delete_file(request=request)
|
.venv/lib/python3.11/site-packages/google/generativeai/generative_models.py
ADDED
|
@@ -0,0 +1,875 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Classes for working with the Gemini models."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from collections.abc import Iterable
|
| 6 |
+
import textwrap
|
| 7 |
+
from typing import Any, Union, overload
|
| 8 |
+
import reprlib
|
| 9 |
+
|
| 10 |
+
# pylint: disable=bad-continuation, line-too-long
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import google.api_core.exceptions
|
| 14 |
+
from google.generativeai import protos
|
| 15 |
+
from google.generativeai import client
|
| 16 |
+
|
| 17 |
+
from google.generativeai import caching
|
| 18 |
+
from google.generativeai.types import content_types
|
| 19 |
+
from google.generativeai.types import generation_types
|
| 20 |
+
from google.generativeai.types import helper_types
|
| 21 |
+
from google.generativeai.types import safety_types
|
| 22 |
+
|
| 23 |
+
_USER_ROLE = "user"
|
| 24 |
+
_MODEL_ROLE = "model"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class GenerativeModel:
|
| 28 |
+
"""
|
| 29 |
+
The `genai.GenerativeModel` class wraps default parameters for calls to
|
| 30 |
+
`GenerativeModel.generate_content`, `GenerativeModel.count_tokens`, and
|
| 31 |
+
`GenerativeModel.start_chat`.
|
| 32 |
+
|
| 33 |
+
This family of functionality is designed to support multi-turn conversations, and multimodal
|
| 34 |
+
requests. What media-types are supported for input and output is model-dependant.
|
| 35 |
+
|
| 36 |
+
>>> import google.generativeai as genai
|
| 37 |
+
>>> import PIL.Image
|
| 38 |
+
>>> genai.configure(api_key='YOUR_API_KEY')
|
| 39 |
+
>>> model = genai.GenerativeModel('models/gemini-1.5-flash')
|
| 40 |
+
>>> result = model.generate_content('Tell me a story about a magic backpack')
|
| 41 |
+
>>> result.text
|
| 42 |
+
"In the quaint little town of Lakeside, there lived a young girl named Lily..."
|
| 43 |
+
|
| 44 |
+
Multimodal input:
|
| 45 |
+
|
| 46 |
+
>>> model = genai.GenerativeModel('models/gemini-1.5-flash')
|
| 47 |
+
>>> result = model.generate_content([
|
| 48 |
+
... "Give me a recipe for these:", PIL.Image.open('scones.jpeg')])
|
| 49 |
+
>>> result.text
|
| 50 |
+
"**Blueberry Scones** ..."
|
| 51 |
+
|
| 52 |
+
Multi-turn conversation:
|
| 53 |
+
|
| 54 |
+
>>> chat = model.start_chat()
|
| 55 |
+
>>> response = chat.send_message("Hi, I have some questions for you.")
|
| 56 |
+
>>> response.text
|
| 57 |
+
"Sure, I'll do my best to answer your questions..."
|
| 58 |
+
|
| 59 |
+
To list the compatible model names use:
|
| 60 |
+
|
| 61 |
+
>>> for m in genai.list_models():
|
| 62 |
+
... if 'generateContent' in m.supported_generation_methods:
|
| 63 |
+
... print(m.name)
|
| 64 |
+
|
| 65 |
+
Arguments:
|
| 66 |
+
model_name: The name of the model to query. To list compatible models use
|
| 67 |
+
safety_settings: Sets the default safety filters. This controls which content is blocked
|
| 68 |
+
by the api before being returned.
|
| 69 |
+
generation_config: A `genai.GenerationConfig` setting the default generation parameters to
|
| 70 |
+
use.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
model_name: str = "gemini-1.5-flash-002",
|
| 76 |
+
safety_settings: safety_types.SafetySettingOptions | None = None,
|
| 77 |
+
generation_config: generation_types.GenerationConfigType | None = None,
|
| 78 |
+
tools: content_types.FunctionLibraryType | None = None,
|
| 79 |
+
tool_config: content_types.ToolConfigType | None = None,
|
| 80 |
+
system_instruction: content_types.ContentType | None = None,
|
| 81 |
+
):
|
| 82 |
+
if "/" not in model_name:
|
| 83 |
+
model_name = "models/" + model_name
|
| 84 |
+
self._model_name = model_name
|
| 85 |
+
self._safety_settings = safety_types.to_easy_safety_dict(safety_settings)
|
| 86 |
+
self._generation_config = generation_types.to_generation_config_dict(generation_config)
|
| 87 |
+
self._tools = content_types.to_function_library(tools)
|
| 88 |
+
|
| 89 |
+
if tool_config is None:
|
| 90 |
+
self._tool_config = None
|
| 91 |
+
else:
|
| 92 |
+
self._tool_config = content_types.to_tool_config(tool_config)
|
| 93 |
+
|
| 94 |
+
if system_instruction is None:
|
| 95 |
+
self._system_instruction = None
|
| 96 |
+
else:
|
| 97 |
+
self._system_instruction = content_types.to_content(system_instruction)
|
| 98 |
+
|
| 99 |
+
self._client = None
|
| 100 |
+
self._async_client = None
|
| 101 |
+
|
| 102 |
+
@property
|
| 103 |
+
def cached_content(self) -> str:
|
| 104 |
+
return getattr(self, "_cached_content", None)
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def model_name(self):
|
| 108 |
+
return self._model_name
|
| 109 |
+
|
| 110 |
+
def __str__(self):
|
| 111 |
+
def maybe_text(content):
|
| 112 |
+
if content and len(content.parts) and (t := content.parts[0].text):
|
| 113 |
+
return repr(t)
|
| 114 |
+
return content
|
| 115 |
+
|
| 116 |
+
return textwrap.dedent(
|
| 117 |
+
f"""\
|
| 118 |
+
genai.GenerativeModel(
|
| 119 |
+
model_name='{self.model_name}',
|
| 120 |
+
generation_config={self._generation_config},
|
| 121 |
+
safety_settings={self._safety_settings},
|
| 122 |
+
tools={self._tools},
|
| 123 |
+
system_instruction={maybe_text(self._system_instruction)},
|
| 124 |
+
cached_content={self.cached_content}
|
| 125 |
+
)"""
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
__repr__ = __str__
|
| 129 |
+
|
| 130 |
+
def _prepare_request(
|
| 131 |
+
self,
|
| 132 |
+
*,
|
| 133 |
+
contents: content_types.ContentsType,
|
| 134 |
+
generation_config: generation_types.GenerationConfigType | None = None,
|
| 135 |
+
safety_settings: safety_types.SafetySettingOptions | None = None,
|
| 136 |
+
tools: content_types.FunctionLibraryType | None,
|
| 137 |
+
tool_config: content_types.ToolConfigType | None,
|
| 138 |
+
) -> protos.GenerateContentRequest:
|
| 139 |
+
"""Creates a `protos.GenerateContentRequest` from raw inputs."""
|
| 140 |
+
if hasattr(self, "_cached_content") and any([self._system_instruction, tools, tool_config]):
|
| 141 |
+
raise ValueError(
|
| 142 |
+
"`tools`, `tool_config`, `system_instruction` cannot be set on a model instantiated with `cached_content` as its context."
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
tools_lib = self._get_tools_lib(tools)
|
| 146 |
+
if tools_lib is not None:
|
| 147 |
+
tools_lib = tools_lib.to_proto()
|
| 148 |
+
|
| 149 |
+
if tool_config is None:
|
| 150 |
+
tool_config = self._tool_config
|
| 151 |
+
else:
|
| 152 |
+
tool_config = content_types.to_tool_config(tool_config)
|
| 153 |
+
|
| 154 |
+
contents = content_types.to_contents(contents)
|
| 155 |
+
|
| 156 |
+
generation_config = generation_types.to_generation_config_dict(generation_config)
|
| 157 |
+
merged_gc = self._generation_config.copy()
|
| 158 |
+
merged_gc.update(generation_config)
|
| 159 |
+
|
| 160 |
+
safety_settings = safety_types.to_easy_safety_dict(safety_settings)
|
| 161 |
+
merged_ss = self._safety_settings.copy()
|
| 162 |
+
merged_ss.update(safety_settings)
|
| 163 |
+
merged_ss = safety_types.normalize_safety_settings(merged_ss)
|
| 164 |
+
|
| 165 |
+
return protos.GenerateContentRequest(
|
| 166 |
+
model=self._model_name,
|
| 167 |
+
contents=contents,
|
| 168 |
+
generation_config=merged_gc,
|
| 169 |
+
safety_settings=merged_ss,
|
| 170 |
+
tools=tools_lib,
|
| 171 |
+
tool_config=tool_config,
|
| 172 |
+
system_instruction=self._system_instruction,
|
| 173 |
+
cached_content=self.cached_content,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def _get_tools_lib(
|
| 177 |
+
self, tools: content_types.FunctionLibraryType
|
| 178 |
+
) -> content_types.FunctionLibrary | None:
|
| 179 |
+
if tools is None:
|
| 180 |
+
return self._tools
|
| 181 |
+
else:
|
| 182 |
+
return content_types.to_function_library(tools)
|
| 183 |
+
|
| 184 |
+
@overload
|
| 185 |
+
@classmethod
|
| 186 |
+
def from_cached_content(
|
| 187 |
+
cls,
|
| 188 |
+
cached_content: str,
|
| 189 |
+
*,
|
| 190 |
+
generation_config: generation_types.GenerationConfigType | None = None,
|
| 191 |
+
safety_settings: safety_types.SafetySettingOptions | None = None,
|
| 192 |
+
) -> GenerativeModel: ...
|
| 193 |
+
|
| 194 |
+
@overload
|
| 195 |
+
@classmethod
|
| 196 |
+
def from_cached_content(
|
| 197 |
+
cls,
|
| 198 |
+
cached_content: caching.CachedContent,
|
| 199 |
+
*,
|
| 200 |
+
generation_config: generation_types.GenerationConfigType | None = None,
|
| 201 |
+
safety_settings: safety_types.SafetySettingOptions | None = None,
|
| 202 |
+
) -> GenerativeModel: ...
|
| 203 |
+
|
| 204 |
+
@classmethod
|
| 205 |
+
def from_cached_content(
|
| 206 |
+
cls,
|
| 207 |
+
cached_content: str | caching.CachedContent,
|
| 208 |
+
*,
|
| 209 |
+
generation_config: generation_types.GenerationConfigType | None = None,
|
| 210 |
+
safety_settings: safety_types.SafetySettingOptions | None = None,
|
| 211 |
+
) -> GenerativeModel:
|
| 212 |
+
"""Creates a model with `cached_content` as model's context.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
cached_content: context for the model.
|
| 216 |
+
generation_config: Overrides for the model's generation config.
|
| 217 |
+
safety_settings: Overrides for the model's safety settings.
|
| 218 |
+
|
| 219 |
+
Returns:
|
| 220 |
+
`GenerativeModel` object with `cached_content` as its context.
|
| 221 |
+
"""
|
| 222 |
+
if isinstance(cached_content, str):
|
| 223 |
+
cached_content = caching.CachedContent.get(name=cached_content)
|
| 224 |
+
|
| 225 |
+
# call __init__ to set the model's `generation_config`, `safety_settings`.
|
| 226 |
+
# `model_name` will be the name of the model for which the `cached_content` was created.
|
| 227 |
+
self = cls(
|
| 228 |
+
model_name=cached_content.model,
|
| 229 |
+
generation_config=generation_config,
|
| 230 |
+
safety_settings=safety_settings,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# set the model's context.
|
| 234 |
+
setattr(self, "_cached_content", cached_content.name)
|
| 235 |
+
return self
|
| 236 |
+
|
| 237 |
+
def generate_content(
|
| 238 |
+
self,
|
| 239 |
+
contents: content_types.ContentsType,
|
| 240 |
+
*,
|
| 241 |
+
generation_config: generation_types.GenerationConfigType | None = None,
|
| 242 |
+
safety_settings: safety_types.SafetySettingOptions | None = None,
|
| 243 |
+
stream: bool = False,
|
| 244 |
+
tools: content_types.FunctionLibraryType | None = None,
|
| 245 |
+
tool_config: content_types.ToolConfigType | None = None,
|
| 246 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 247 |
+
) -> generation_types.GenerateContentResponse:
|
| 248 |
+
"""A multipurpose function to generate responses from the model.
|
| 249 |
+
|
| 250 |
+
This `GenerativeModel.generate_content` method can handle multimodal input, and multi-turn
|
| 251 |
+
conversations.
|
| 252 |
+
|
| 253 |
+
>>> model = genai.GenerativeModel('models/gemini-1.5-flash')
|
| 254 |
+
>>> response = model.generate_content('Tell me a story about a magic backpack')
|
| 255 |
+
>>> response.text
|
| 256 |
+
|
| 257 |
+
### Streaming
|
| 258 |
+
|
| 259 |
+
This method supports streaming with the `stream=True`. The result has the same type as the non streaming case,
|
| 260 |
+
but you can iterate over the response chunks as they become available:
|
| 261 |
+
|
| 262 |
+
>>> response = model.generate_content('Tell me a story about a magic backpack', stream=True)
|
| 263 |
+
>>> for chunk in response:
|
| 264 |
+
... print(chunk.text)
|
| 265 |
+
|
| 266 |
+
### Multi-turn
|
| 267 |
+
|
| 268 |
+
This method supports multi-turn chats but is **stateless**: the entire conversation history needs to be sent with each
|
| 269 |
+
request. This takes some manual management but gives you complete control:
|
| 270 |
+
|
| 271 |
+
>>> messages = [{'role':'user', 'parts': ['hello']}]
|
| 272 |
+
>>> response = model.generate_content(messages) # "Hello, how can I help"
|
| 273 |
+
>>> messages.append(response.candidates[0].content)
|
| 274 |
+
>>> messages.append({'role':'user', 'parts': ['How does quantum physics work?']})
|
| 275 |
+
>>> response = model.generate_content(messages)
|
| 276 |
+
|
| 277 |
+
For a simpler multi-turn interface see `GenerativeModel.start_chat`.
|
| 278 |
+
|
| 279 |
+
### Input type flexibility
|
| 280 |
+
|
| 281 |
+
While the underlying API strictly expects a `list[protos.Content]` objects, this method
|
| 282 |
+
will convert the user input into the correct type. The hierarchy of types that can be
|
| 283 |
+
converted is below. Any of these objects can be passed as an equivalent `dict`.
|
| 284 |
+
|
| 285 |
+
* `Iterable[protos.Content]`
|
| 286 |
+
* `protos.Content`
|
| 287 |
+
* `Iterable[protos.Part]`
|
| 288 |
+
* `protos.Part`
|
| 289 |
+
* `str`, `Image`, or `protos.Blob`
|
| 290 |
+
|
| 291 |
+
In an `Iterable[protos.Content]` each `content` is a separate message.
|
| 292 |
+
But note that an `Iterable[protos.Part]` is taken as the parts of a single message.
|
| 293 |
+
|
| 294 |
+
Arguments:
|
| 295 |
+
contents: The contents serving as the model's prompt.
|
| 296 |
+
generation_config: Overrides for the model's generation config.
|
| 297 |
+
safety_settings: Overrides for the model's safety settings.
|
| 298 |
+
stream: If True, yield response chunks as they are generated.
|
| 299 |
+
tools: `protos.Tools` more info coming soon.
|
| 300 |
+
request_options: Options for the request.
|
| 301 |
+
"""
|
| 302 |
+
if not contents:
|
| 303 |
+
raise TypeError("contents must not be empty")
|
| 304 |
+
|
| 305 |
+
request = self._prepare_request(
|
| 306 |
+
contents=contents,
|
| 307 |
+
generation_config=generation_config,
|
| 308 |
+
safety_settings=safety_settings,
|
| 309 |
+
tools=tools,
|
| 310 |
+
tool_config=tool_config,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
if request.contents and not request.contents[-1].role:
|
| 314 |
+
request.contents[-1].role = _USER_ROLE
|
| 315 |
+
|
| 316 |
+
if self._client is None:
|
| 317 |
+
self._client = client.get_default_generative_client()
|
| 318 |
+
|
| 319 |
+
if request_options is None:
|
| 320 |
+
request_options = {}
|
| 321 |
+
|
| 322 |
+
try:
|
| 323 |
+
if stream:
|
| 324 |
+
with generation_types.rewrite_stream_error():
|
| 325 |
+
iterator = self._client.stream_generate_content(
|
| 326 |
+
request,
|
| 327 |
+
**request_options,
|
| 328 |
+
)
|
| 329 |
+
return generation_types.GenerateContentResponse.from_iterator(iterator)
|
| 330 |
+
else:
|
| 331 |
+
response = self._client.generate_content(
|
| 332 |
+
request,
|
| 333 |
+
**request_options,
|
| 334 |
+
)
|
| 335 |
+
return generation_types.GenerateContentResponse.from_response(response)
|
| 336 |
+
except google.api_core.exceptions.InvalidArgument as e:
|
| 337 |
+
if e.message.startswith("Request payload size exceeds the limit:"):
|
| 338 |
+
e.message += (
|
| 339 |
+
" The file size is too large. Please use the File API to upload your files instead. "
|
| 340 |
+
"Example: `f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`"
|
| 341 |
+
)
|
| 342 |
+
raise
|
| 343 |
+
|
| 344 |
+
async def generate_content_async(
|
| 345 |
+
self,
|
| 346 |
+
contents: content_types.ContentsType,
|
| 347 |
+
*,
|
| 348 |
+
generation_config: generation_types.GenerationConfigType | None = None,
|
| 349 |
+
safety_settings: safety_types.SafetySettingOptions | None = None,
|
| 350 |
+
stream: bool = False,
|
| 351 |
+
tools: content_types.FunctionLibraryType | None = None,
|
| 352 |
+
tool_config: content_types.ToolConfigType | None = None,
|
| 353 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 354 |
+
) -> generation_types.AsyncGenerateContentResponse:
|
| 355 |
+
"""The async version of `GenerativeModel.generate_content`."""
|
| 356 |
+
if not contents:
|
| 357 |
+
raise TypeError("contents must not be empty")
|
| 358 |
+
|
| 359 |
+
request = self._prepare_request(
|
| 360 |
+
contents=contents,
|
| 361 |
+
generation_config=generation_config,
|
| 362 |
+
safety_settings=safety_settings,
|
| 363 |
+
tools=tools,
|
| 364 |
+
tool_config=tool_config,
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
if request.contents and not request.contents[-1].role:
|
| 368 |
+
request.contents[-1].role = _USER_ROLE
|
| 369 |
+
|
| 370 |
+
if self._async_client is None:
|
| 371 |
+
self._async_client = client.get_default_generative_async_client()
|
| 372 |
+
|
| 373 |
+
if request_options is None:
|
| 374 |
+
request_options = {}
|
| 375 |
+
|
| 376 |
+
try:
|
| 377 |
+
if stream:
|
| 378 |
+
with generation_types.rewrite_stream_error():
|
| 379 |
+
iterator = await self._async_client.stream_generate_content(
|
| 380 |
+
request,
|
| 381 |
+
**request_options,
|
| 382 |
+
)
|
| 383 |
+
return await generation_types.AsyncGenerateContentResponse.from_aiterator(iterator)
|
| 384 |
+
else:
|
| 385 |
+
response = await self._async_client.generate_content(
|
| 386 |
+
request,
|
| 387 |
+
**request_options,
|
| 388 |
+
)
|
| 389 |
+
return generation_types.AsyncGenerateContentResponse.from_response(response)
|
| 390 |
+
except google.api_core.exceptions.InvalidArgument as e:
|
| 391 |
+
if e.message.startswith("Request payload size exceeds the limit:"):
|
| 392 |
+
e.message += (
|
| 393 |
+
" The file size is too large. Please use the File API to upload your files instead. "
|
| 394 |
+
"Example: `f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`"
|
| 395 |
+
)
|
| 396 |
+
raise
|
| 397 |
+
|
| 398 |
+
# fmt: off
|
| 399 |
+
def count_tokens(
|
| 400 |
+
self,
|
| 401 |
+
contents: content_types.ContentsType = None,
|
| 402 |
+
*,
|
| 403 |
+
generation_config: generation_types.GenerationConfigType | None = None,
|
| 404 |
+
safety_settings: safety_types.SafetySettingOptions | None = None,
|
| 405 |
+
tools: content_types.FunctionLibraryType | None = None,
|
| 406 |
+
tool_config: content_types.ToolConfigType | None = None,
|
| 407 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 408 |
+
) -> protos.CountTokensResponse:
|
| 409 |
+
if request_options is None:
|
| 410 |
+
request_options = {}
|
| 411 |
+
|
| 412 |
+
if self._client is None:
|
| 413 |
+
self._client = client.get_default_generative_client()
|
| 414 |
+
|
| 415 |
+
request = protos.CountTokensRequest(
|
| 416 |
+
model=self.model_name,
|
| 417 |
+
generate_content_request=self._prepare_request(
|
| 418 |
+
contents=contents,
|
| 419 |
+
generation_config=generation_config,
|
| 420 |
+
safety_settings=safety_settings,
|
| 421 |
+
tools=tools,
|
| 422 |
+
tool_config=tool_config,
|
| 423 |
+
))
|
| 424 |
+
return self._client.count_tokens(request, **request_options)
|
| 425 |
+
|
| 426 |
+
async def count_tokens_async(
|
| 427 |
+
self,
|
| 428 |
+
contents: content_types.ContentsType = None,
|
| 429 |
+
*,
|
| 430 |
+
generation_config: generation_types.GenerationConfigType | None = None,
|
| 431 |
+
safety_settings: safety_types.SafetySettingOptions | None = None,
|
| 432 |
+
tools: content_types.FunctionLibraryType | None = None,
|
| 433 |
+
tool_config: content_types.ToolConfigType | None = None,
|
| 434 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 435 |
+
) -> protos.CountTokensResponse:
|
| 436 |
+
if request_options is None:
|
| 437 |
+
request_options = {}
|
| 438 |
+
|
| 439 |
+
if self._async_client is None:
|
| 440 |
+
self._async_client = client.get_default_generative_async_client()
|
| 441 |
+
|
| 442 |
+
request = protos.CountTokensRequest(
|
| 443 |
+
model=self.model_name,
|
| 444 |
+
generate_content_request=self._prepare_request(
|
| 445 |
+
contents=contents,
|
| 446 |
+
generation_config=generation_config,
|
| 447 |
+
safety_settings=safety_settings,
|
| 448 |
+
tools=tools,
|
| 449 |
+
tool_config=tool_config,
|
| 450 |
+
))
|
| 451 |
+
return await self._async_client.count_tokens(request, **request_options)
|
| 452 |
+
|
| 453 |
+
# fmt: on
|
| 454 |
+
|
| 455 |
+
def start_chat(
|
| 456 |
+
self,
|
| 457 |
+
*,
|
| 458 |
+
history: Iterable[content_types.StrictContentType] | None = None,
|
| 459 |
+
enable_automatic_function_calling: bool = False,
|
| 460 |
+
) -> ChatSession:
|
| 461 |
+
"""Returns a `genai.ChatSession` attached to this model.
|
| 462 |
+
|
| 463 |
+
>>> model = genai.GenerativeModel()
|
| 464 |
+
>>> chat = model.start_chat(history=[...])
|
| 465 |
+
>>> response = chat.send_message("Hello?")
|
| 466 |
+
|
| 467 |
+
Arguments:
|
| 468 |
+
history: An iterable of `protos.Content` objects, or equivalents to initialize the session.
|
| 469 |
+
"""
|
| 470 |
+
if self._generation_config.get("candidate_count", 1) > 1:
|
| 471 |
+
raise ValueError(
|
| 472 |
+
"Invalid configuration: The chat functionality does not support `candidate_count` greater than 1."
|
| 473 |
+
)
|
| 474 |
+
return ChatSession(
|
| 475 |
+
model=self,
|
| 476 |
+
history=history,
|
| 477 |
+
enable_automatic_function_calling=enable_automatic_function_calling,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class ChatSession:
|
| 482 |
+
"""Contains an ongoing conversation with the model.
|
| 483 |
+
|
| 484 |
+
>>> model = genai.GenerativeModel('models/gemini-1.5-flash')
|
| 485 |
+
>>> chat = model.start_chat()
|
| 486 |
+
>>> response = chat.send_message("Hello")
|
| 487 |
+
>>> print(response.text)
|
| 488 |
+
>>> response = chat.send_message("Hello again")
|
| 489 |
+
>>> print(response.text)
|
| 490 |
+
>>> response = chat.send_message(...
|
| 491 |
+
|
| 492 |
+
This `ChatSession` object collects the messages sent and received, in its
|
| 493 |
+
`ChatSession.history` attribute.
|
| 494 |
+
|
| 495 |
+
Arguments:
|
| 496 |
+
model: The model to use in the chat.
|
| 497 |
+
history: A chat history to initialize the object with.
|
| 498 |
+
"""
|
| 499 |
+
|
| 500 |
+
def __init__(
|
| 501 |
+
self,
|
| 502 |
+
model: GenerativeModel,
|
| 503 |
+
history: Iterable[content_types.StrictContentType] | None = None,
|
| 504 |
+
enable_automatic_function_calling: bool = False,
|
| 505 |
+
):
|
| 506 |
+
self.model: GenerativeModel = model
|
| 507 |
+
self._history: list[protos.Content] = content_types.to_contents(history)
|
| 508 |
+
self._last_sent: protos.Content | None = None
|
| 509 |
+
self._last_received: generation_types.BaseGenerateContentResponse | None = None
|
| 510 |
+
self.enable_automatic_function_calling = enable_automatic_function_calling
|
| 511 |
+
|
| 512 |
+
def send_message(
|
| 513 |
+
self,
|
| 514 |
+
content: content_types.ContentType,
|
| 515 |
+
*,
|
| 516 |
+
generation_config: generation_types.GenerationConfigType = None,
|
| 517 |
+
safety_settings: safety_types.SafetySettingOptions = None,
|
| 518 |
+
stream: bool = False,
|
| 519 |
+
tools: content_types.FunctionLibraryType | None = None,
|
| 520 |
+
tool_config: content_types.ToolConfigType | None = None,
|
| 521 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 522 |
+
) -> generation_types.GenerateContentResponse:
|
| 523 |
+
"""Sends the conversation history with the added message and returns the model's response.
|
| 524 |
+
|
| 525 |
+
Appends the request and response to the conversation history.
|
| 526 |
+
|
| 527 |
+
>>> model = genai.GenerativeModel('models/gemini-1.5-flash')
|
| 528 |
+
>>> chat = model.start_chat()
|
| 529 |
+
>>> response = chat.send_message("Hello")
|
| 530 |
+
>>> print(response.text)
|
| 531 |
+
"Hello! How can I assist you today?"
|
| 532 |
+
>>> len(chat.history)
|
| 533 |
+
2
|
| 534 |
+
|
| 535 |
+
Call it with `stream=True` to receive response chunks as they are generated:
|
| 536 |
+
|
| 537 |
+
>>> chat = model.start_chat()
|
| 538 |
+
>>> response = chat.send_message("Explain quantum physics", stream=True)
|
| 539 |
+
>>> for chunk in response:
|
| 540 |
+
... print(chunk.text, end='')
|
| 541 |
+
|
| 542 |
+
Once iteration over chunks is complete, the `response` and `ChatSession` are in states identical to the
|
| 543 |
+
`stream=False` case. Some properties are not available until iteration is complete.
|
| 544 |
+
|
| 545 |
+
Like `GenerativeModel.generate_content` this method lets you override the model's `generation_config` and
|
| 546 |
+
`safety_settings`.
|
| 547 |
+
|
| 548 |
+
Arguments:
|
| 549 |
+
content: The message contents.
|
| 550 |
+
generation_config: Overrides for the model's generation config.
|
| 551 |
+
safety_settings: Overrides for the model's safety settings.
|
| 552 |
+
stream: If True, yield response chunks as they are generated.
|
| 553 |
+
"""
|
| 554 |
+
if request_options is None:
|
| 555 |
+
request_options = {}
|
| 556 |
+
|
| 557 |
+
if self.enable_automatic_function_calling and stream:
|
| 558 |
+
raise NotImplementedError(
|
| 559 |
+
"Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`."
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
tools_lib = self.model._get_tools_lib(tools)
|
| 563 |
+
|
| 564 |
+
content = content_types.to_content(content)
|
| 565 |
+
|
| 566 |
+
if not content.role:
|
| 567 |
+
content.role = _USER_ROLE
|
| 568 |
+
|
| 569 |
+
history = self.history[:]
|
| 570 |
+
history.append(content)
|
| 571 |
+
|
| 572 |
+
generation_config = generation_types.to_generation_config_dict(generation_config)
|
| 573 |
+
if generation_config.get("candidate_count", 1) > 1:
|
| 574 |
+
raise ValueError(
|
| 575 |
+
"Invalid configuration: The chat functionality does not support `candidate_count` greater than 1."
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
response = self.model.generate_content(
|
| 579 |
+
contents=history,
|
| 580 |
+
generation_config=generation_config,
|
| 581 |
+
safety_settings=safety_settings,
|
| 582 |
+
stream=stream,
|
| 583 |
+
tools=tools_lib,
|
| 584 |
+
tool_config=tool_config,
|
| 585 |
+
request_options=request_options,
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
self._check_response(response=response, stream=stream)
|
| 589 |
+
|
| 590 |
+
if self.enable_automatic_function_calling and tools_lib is not None:
|
| 591 |
+
self.history, content, response = self._handle_afc(
|
| 592 |
+
response=response,
|
| 593 |
+
history=history,
|
| 594 |
+
generation_config=generation_config,
|
| 595 |
+
safety_settings=safety_settings,
|
| 596 |
+
stream=stream,
|
| 597 |
+
tools_lib=tools_lib,
|
| 598 |
+
request_options=request_options,
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
self._last_sent = content
|
| 602 |
+
self._last_received = response
|
| 603 |
+
|
| 604 |
+
return response
|
| 605 |
+
|
| 606 |
+
def _check_response(self, *, response, stream):
|
| 607 |
+
if response.prompt_feedback.block_reason:
|
| 608 |
+
raise generation_types.BlockedPromptException(response.prompt_feedback)
|
| 609 |
+
|
| 610 |
+
if not stream:
|
| 611 |
+
if response.candidates[0].finish_reason not in (
|
| 612 |
+
protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED,
|
| 613 |
+
protos.Candidate.FinishReason.STOP,
|
| 614 |
+
protos.Candidate.FinishReason.MAX_TOKENS,
|
| 615 |
+
):
|
| 616 |
+
raise generation_types.StopCandidateException(response.candidates[0])
|
| 617 |
+
|
| 618 |
+
def _get_function_calls(self, response) -> list[protos.FunctionCall]:
|
| 619 |
+
candidates = response.candidates
|
| 620 |
+
if len(candidates) != 1:
|
| 621 |
+
raise ValueError(
|
| 622 |
+
f"Invalid number of candidates: Automatic function calling only works with 1 candidate, but {len(candidates)} were provided."
|
| 623 |
+
)
|
| 624 |
+
parts = candidates[0].content.parts
|
| 625 |
+
function_calls = [part.function_call for part in parts if part and "function_call" in part]
|
| 626 |
+
return function_calls
|
| 627 |
+
|
| 628 |
+
def _handle_afc(
|
| 629 |
+
self,
|
| 630 |
+
*,
|
| 631 |
+
response,
|
| 632 |
+
history,
|
| 633 |
+
generation_config,
|
| 634 |
+
safety_settings,
|
| 635 |
+
stream,
|
| 636 |
+
tools_lib,
|
| 637 |
+
request_options,
|
| 638 |
+
) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]:
|
| 639 |
+
|
| 640 |
+
while function_calls := self._get_function_calls(response):
|
| 641 |
+
if not all(callable(tools_lib[fc]) for fc in function_calls):
|
| 642 |
+
break
|
| 643 |
+
history.append(response.candidates[0].content)
|
| 644 |
+
|
| 645 |
+
function_response_parts: list[protos.Part] = []
|
| 646 |
+
for fc in function_calls:
|
| 647 |
+
fr = tools_lib(fc)
|
| 648 |
+
assert fr is not None, (
|
| 649 |
+
"Unexpected state: The function reference (fr) should never be None. It should only return None if the declaration "
|
| 650 |
+
"is not callable, which is checked earlier in the code."
|
| 651 |
+
)
|
| 652 |
+
function_response_parts.append(fr)
|
| 653 |
+
|
| 654 |
+
send = protos.Content(role=_USER_ROLE, parts=function_response_parts)
|
| 655 |
+
history.append(send)
|
| 656 |
+
|
| 657 |
+
response = self.model.generate_content(
|
| 658 |
+
contents=history,
|
| 659 |
+
generation_config=generation_config,
|
| 660 |
+
safety_settings=safety_settings,
|
| 661 |
+
stream=stream,
|
| 662 |
+
tools=tools_lib,
|
| 663 |
+
request_options=request_options,
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
self._check_response(response=response, stream=stream)
|
| 667 |
+
|
| 668 |
+
*history, content = history
|
| 669 |
+
return history, content, response
|
| 670 |
+
|
| 671 |
+
async def send_message_async(
|
| 672 |
+
self,
|
| 673 |
+
content: content_types.ContentType,
|
| 674 |
+
*,
|
| 675 |
+
generation_config: generation_types.GenerationConfigType = None,
|
| 676 |
+
safety_settings: safety_types.SafetySettingOptions = None,
|
| 677 |
+
stream: bool = False,
|
| 678 |
+
tools: content_types.FunctionLibraryType | None = None,
|
| 679 |
+
tool_config: content_types.ToolConfigType | None = None,
|
| 680 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 681 |
+
) -> generation_types.AsyncGenerateContentResponse:
|
| 682 |
+
"""The async version of `ChatSession.send_message`."""
|
| 683 |
+
if request_options is None:
|
| 684 |
+
request_options = {}
|
| 685 |
+
|
| 686 |
+
if self.enable_automatic_function_calling and stream:
|
| 687 |
+
raise NotImplementedError(
|
| 688 |
+
"Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`."
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
tools_lib = self.model._get_tools_lib(tools)
|
| 692 |
+
|
| 693 |
+
content = content_types.to_content(content)
|
| 694 |
+
|
| 695 |
+
if not content.role:
|
| 696 |
+
content.role = _USER_ROLE
|
| 697 |
+
|
| 698 |
+
history = self.history[:]
|
| 699 |
+
history.append(content)
|
| 700 |
+
|
| 701 |
+
generation_config = generation_types.to_generation_config_dict(generation_config)
|
| 702 |
+
if generation_config.get("candidate_count", 1) > 1:
|
| 703 |
+
raise ValueError(
|
| 704 |
+
"Invalid configuration: The chat functionality does not support `candidate_count` greater than 1."
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
response = await self.model.generate_content_async(
|
| 708 |
+
contents=history,
|
| 709 |
+
generation_config=generation_config,
|
| 710 |
+
safety_settings=safety_settings,
|
| 711 |
+
stream=stream,
|
| 712 |
+
tools=tools_lib,
|
| 713 |
+
tool_config=tool_config,
|
| 714 |
+
request_options=request_options,
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
self._check_response(response=response, stream=stream)
|
| 718 |
+
|
| 719 |
+
if self.enable_automatic_function_calling and tools_lib is not None:
|
| 720 |
+
self.history, content, response = await self._handle_afc_async(
|
| 721 |
+
response=response,
|
| 722 |
+
history=history,
|
| 723 |
+
generation_config=generation_config,
|
| 724 |
+
safety_settings=safety_settings,
|
| 725 |
+
stream=stream,
|
| 726 |
+
tools_lib=tools_lib,
|
| 727 |
+
request_options=request_options,
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
self._last_sent = content
|
| 731 |
+
self._last_received = response
|
| 732 |
+
|
| 733 |
+
return response
|
| 734 |
+
|
| 735 |
+
async def _handle_afc_async(
|
| 736 |
+
self,
|
| 737 |
+
*,
|
| 738 |
+
response,
|
| 739 |
+
history,
|
| 740 |
+
generation_config,
|
| 741 |
+
safety_settings,
|
| 742 |
+
stream,
|
| 743 |
+
tools_lib,
|
| 744 |
+
request_options,
|
| 745 |
+
) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]:
|
| 746 |
+
|
| 747 |
+
while function_calls := self._get_function_calls(response):
|
| 748 |
+
if not all(callable(tools_lib[fc]) for fc in function_calls):
|
| 749 |
+
break
|
| 750 |
+
history.append(response.candidates[0].content)
|
| 751 |
+
|
| 752 |
+
function_response_parts: list[protos.Part] = []
|
| 753 |
+
for fc in function_calls:
|
| 754 |
+
fr = tools_lib(fc)
|
| 755 |
+
assert fr is not None, (
|
| 756 |
+
"Unexpected state: The function reference (fr) should never be None. It should only return None if the declaration "
|
| 757 |
+
"is not callable, which is checked earlier in the code."
|
| 758 |
+
)
|
| 759 |
+
function_response_parts.append(fr)
|
| 760 |
+
|
| 761 |
+
send = protos.Content(role=_USER_ROLE, parts=function_response_parts)
|
| 762 |
+
history.append(send)
|
| 763 |
+
|
| 764 |
+
response = await self.model.generate_content_async(
|
| 765 |
+
contents=history,
|
| 766 |
+
generation_config=generation_config,
|
| 767 |
+
safety_settings=safety_settings,
|
| 768 |
+
stream=stream,
|
| 769 |
+
tools=tools_lib,
|
| 770 |
+
request_options=request_options,
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
self._check_response(response=response, stream=stream)
|
| 774 |
+
|
| 775 |
+
*history, content = history
|
| 776 |
+
return history, content, response
|
| 777 |
+
|
| 778 |
+
def __copy__(self):
|
| 779 |
+
return ChatSession(
|
| 780 |
+
model=self.model,
|
| 781 |
+
# Be sure the copy doesn't share the history.
|
| 782 |
+
history=list(self.history),
|
| 783 |
+
)
|
| 784 |
+
|
| 785 |
+
def rewind(self) -> tuple[protos.Content, protos.Content]:
|
| 786 |
+
"""Removes the last request/response pair from the chat history."""
|
| 787 |
+
if self._last_received is None:
|
| 788 |
+
result = self._history.pop(-2), self._history.pop()
|
| 789 |
+
return result
|
| 790 |
+
else:
|
| 791 |
+
result = self._last_sent, self._last_received.candidates[0].content
|
| 792 |
+
self._last_sent = None
|
| 793 |
+
self._last_received = None
|
| 794 |
+
return result
|
| 795 |
+
|
| 796 |
+
@property
|
| 797 |
+
def last(self) -> generation_types.BaseGenerateContentResponse | None:
|
| 798 |
+
"""returns the last received `genai.GenerateContentResponse`"""
|
| 799 |
+
return self._last_received
|
| 800 |
+
|
| 801 |
+
@property
|
| 802 |
+
def history(self) -> list[protos.Content]:
|
| 803 |
+
"""The chat history."""
|
| 804 |
+
last = self._last_received
|
| 805 |
+
if last is None:
|
| 806 |
+
return self._history
|
| 807 |
+
|
| 808 |
+
if last.candidates[0].finish_reason not in (
|
| 809 |
+
protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED,
|
| 810 |
+
protos.Candidate.FinishReason.STOP,
|
| 811 |
+
protos.Candidate.FinishReason.MAX_TOKENS,
|
| 812 |
+
):
|
| 813 |
+
error = generation_types.StopCandidateException(last.candidates[0])
|
| 814 |
+
last._error = error
|
| 815 |
+
|
| 816 |
+
if last._error is not None:
|
| 817 |
+
raise generation_types.BrokenResponseError(
|
| 818 |
+
"Unable to build a coherent chat history due to a broken streaming response. "
|
| 819 |
+
"Refer to the previous exception for details. "
|
| 820 |
+
"To inspect the last response object, use `chat.last`. "
|
| 821 |
+
"To remove the last request/response `Content` objects from the chat, "
|
| 822 |
+
"call `last_send, last_received = chat.rewind()` and continue without it."
|
| 823 |
+
) from last._error
|
| 824 |
+
|
| 825 |
+
sent = self._last_sent
|
| 826 |
+
received = last.candidates[0].content
|
| 827 |
+
if not received.role:
|
| 828 |
+
received.role = _MODEL_ROLE
|
| 829 |
+
self._history.extend([sent, received])
|
| 830 |
+
|
| 831 |
+
self._last_sent = None
|
| 832 |
+
self._last_received = None
|
| 833 |
+
|
| 834 |
+
return self._history
|
| 835 |
+
|
| 836 |
+
@history.setter
|
| 837 |
+
def history(self, history):
|
| 838 |
+
self._history = content_types.to_contents(history)
|
| 839 |
+
self._last_sent = None
|
| 840 |
+
self._last_received = None
|
| 841 |
+
|
| 842 |
+
def __repr__(self) -> str:
|
| 843 |
+
_dict_repr = reprlib.Repr()
|
| 844 |
+
_model = str(self.model).replace("\n", "\n" + " " * 4)
|
| 845 |
+
|
| 846 |
+
def content_repr(x):
|
| 847 |
+
return f"protos.Content({_dict_repr.repr(type(x).to_dict(x))})"
|
| 848 |
+
|
| 849 |
+
try:
|
| 850 |
+
history = list(self.history)
|
| 851 |
+
except (generation_types.BrokenResponseError, generation_types.IncompleteIterationError):
|
| 852 |
+
history = list(self._history)
|
| 853 |
+
|
| 854 |
+
if self._last_sent is not None:
|
| 855 |
+
history.append(self._last_sent)
|
| 856 |
+
history = [content_repr(x) for x in history]
|
| 857 |
+
|
| 858 |
+
last_received = self._last_received
|
| 859 |
+
if last_received is not None:
|
| 860 |
+
if last_received._error is not None:
|
| 861 |
+
history.append("<STREAMING ERROR>")
|
| 862 |
+
else:
|
| 863 |
+
history.append("<STREAMING IN PROGRESS>")
|
| 864 |
+
|
| 865 |
+
_history = ",\n " + f"history=[{', '.join(history)}]\n)"
|
| 866 |
+
|
| 867 |
+
return (
|
| 868 |
+
textwrap.dedent(
|
| 869 |
+
f"""\
|
| 870 |
+
ChatSession(
|
| 871 |
+
model="""
|
| 872 |
+
)
|
| 873 |
+
+ _model
|
| 874 |
+
+ _history
|
| 875 |
+
)
|
.venv/lib/python3.11/site-packages/google/generativeai/models.py
ADDED
|
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import typing
|
| 18 |
+
from typing import Any, Literal
|
| 19 |
+
|
| 20 |
+
import google.ai.generativelanguage as glm
|
| 21 |
+
|
| 22 |
+
from google.generativeai import protos
|
| 23 |
+
from google.generativeai import operations
|
| 24 |
+
from google.generativeai.client import get_default_model_client
|
| 25 |
+
from google.generativeai.types import model_types
|
| 26 |
+
from google.generativeai.types import helper_types
|
| 27 |
+
from google.api_core import operation
|
| 28 |
+
from google.api_core import protobuf_helpers
|
| 29 |
+
from google.protobuf import field_mask_pb2
|
| 30 |
+
from google.generativeai.utils import flatten_update_paths
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_model(
|
| 34 |
+
name: model_types.AnyModelNameOptions,
|
| 35 |
+
*,
|
| 36 |
+
client=None,
|
| 37 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 38 |
+
) -> model_types.Model | model_types.TunedModel:
|
| 39 |
+
"""Calls the API to fetch a model by name.
|
| 40 |
+
|
| 41 |
+
```
|
| 42 |
+
import pprint
|
| 43 |
+
model = genai.get_model('models/gemini-1.5-flash')
|
| 44 |
+
pprint.pprint(model)
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
name: The name of the model to fetch. Should start with `models/`
|
| 49 |
+
client: The client to use.
|
| 50 |
+
request_options: Options for the request.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
A `types.Model`
|
| 54 |
+
"""
|
| 55 |
+
name = model_types.make_model_name(name)
|
| 56 |
+
if name.startswith("models/"):
|
| 57 |
+
return get_base_model(name, client=client, request_options=request_options)
|
| 58 |
+
elif name.startswith("tunedModels/"):
|
| 59 |
+
return get_tuned_model(name, client=client, request_options=request_options)
|
| 60 |
+
else:
|
| 61 |
+
raise ValueError(
|
| 62 |
+
f"Invalid model name: Model names must start with `models/` or `tunedModels/`. Received: {name}"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_base_model(
|
| 67 |
+
name: model_types.BaseModelNameOptions,
|
| 68 |
+
*,
|
| 69 |
+
client=None,
|
| 70 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 71 |
+
) -> model_types.Model:
|
| 72 |
+
"""Calls the API to fetch a base model by name.
|
| 73 |
+
|
| 74 |
+
```
|
| 75 |
+
import pprint
|
| 76 |
+
model = genai.get_base_model('models/chat-bison-001')
|
| 77 |
+
pprint.pprint(model)
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
name: The name of the model to fetch. Should start with `models/`
|
| 82 |
+
client: The client to use.
|
| 83 |
+
request_options: Options for the request.
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
A `types.Model`.
|
| 87 |
+
"""
|
| 88 |
+
if request_options is None:
|
| 89 |
+
request_options = {}
|
| 90 |
+
|
| 91 |
+
if client is None:
|
| 92 |
+
client = get_default_model_client()
|
| 93 |
+
|
| 94 |
+
name = model_types.make_model_name(name)
|
| 95 |
+
if not name.startswith("models/"):
|
| 96 |
+
raise ValueError(
|
| 97 |
+
f"Invalid model name: Base model names must start with `models/`. Received: {name}"
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
result = client.get_model(name=name, **request_options)
|
| 101 |
+
result = type(result).to_dict(result)
|
| 102 |
+
return model_types.Model(**result)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_tuned_model(
|
| 106 |
+
name: model_types.TunedModelNameOptions,
|
| 107 |
+
*,
|
| 108 |
+
client=None,
|
| 109 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 110 |
+
) -> model_types.TunedModel:
|
| 111 |
+
"""Calls the API to fetch a tuned model by name.
|
| 112 |
+
|
| 113 |
+
```
|
| 114 |
+
import pprint
|
| 115 |
+
model = genai.get_tuned_model('tunedModels/gemini-1.5-flash')
|
| 116 |
+
pprint.pprint(model)
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
name: The name of the model to fetch. Should start with `tunedModels/`
|
| 121 |
+
client: The client to use.
|
| 122 |
+
request_options: Options for the request.
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
A `types.TunedModel`.
|
| 126 |
+
"""
|
| 127 |
+
if request_options is None:
|
| 128 |
+
request_options = {}
|
| 129 |
+
|
| 130 |
+
if client is None:
|
| 131 |
+
client = get_default_model_client()
|
| 132 |
+
|
| 133 |
+
name = model_types.make_model_name(name)
|
| 134 |
+
|
| 135 |
+
if not name.startswith("tunedModels/"):
|
| 136 |
+
raise ValueError(
|
| 137 |
+
f"Invalid model name: Tuned model names must start with `tunedModels/`. Received: {name}"
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
result = client.get_tuned_model(name=name, **request_options)
|
| 141 |
+
|
| 142 |
+
return model_types.decode_tuned_model(result)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_base_model_name(
|
| 146 |
+
model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None
|
| 147 |
+
):
|
| 148 |
+
"""Calls the API to fetch the base model name of a model."""
|
| 149 |
+
|
| 150 |
+
if isinstance(model, str):
|
| 151 |
+
if model.startswith("tunedModels/"):
|
| 152 |
+
model = get_model(model, client=client)
|
| 153 |
+
base_model = model.base_model
|
| 154 |
+
else:
|
| 155 |
+
base_model = model
|
| 156 |
+
elif isinstance(model, model_types.TunedModel):
|
| 157 |
+
base_model = model.base_model
|
| 158 |
+
elif isinstance(model, model_types.Model):
|
| 159 |
+
base_model = model.name
|
| 160 |
+
elif isinstance(model, protos.Model):
|
| 161 |
+
base_model = model.name
|
| 162 |
+
elif isinstance(model, protos.TunedModel):
|
| 163 |
+
base_model = getattr(model, "base_model", None)
|
| 164 |
+
if not base_model:
|
| 165 |
+
base_model = model.tuned_model_source.base_model
|
| 166 |
+
else:
|
| 167 |
+
raise TypeError(
|
| 168 |
+
f"Invalid model: The provided model '{model}' is not recognized or supported. "
|
| 169 |
+
"Supported types are: str, model_types.TunedModel, model_types.Model, protos.Model, and protos.TunedModel."
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
return base_model
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def list_models(
|
| 176 |
+
*,
|
| 177 |
+
page_size: int | None = 50,
|
| 178 |
+
client: glm.ModelServiceClient | None = None,
|
| 179 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 180 |
+
) -> model_types.ModelsIterable:
|
| 181 |
+
"""Calls the API to list all available models.
|
| 182 |
+
|
| 183 |
+
```
|
| 184 |
+
import pprint
|
| 185 |
+
for model in genai.list_models():
|
| 186 |
+
pprint.pprint(model)
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
page_size: How many `types.Models` to fetch per page (api call).
|
| 191 |
+
client: You may pass a `glm.ModelServiceClient` instead of using the default client.
|
| 192 |
+
request_options: Options for the request.
|
| 193 |
+
|
| 194 |
+
Yields:
|
| 195 |
+
`types.Model` objects.
|
| 196 |
+
|
| 197 |
+
"""
|
| 198 |
+
if request_options is None:
|
| 199 |
+
request_options = {}
|
| 200 |
+
|
| 201 |
+
if client is None:
|
| 202 |
+
client = get_default_model_client()
|
| 203 |
+
|
| 204 |
+
for model in client.list_models(page_size=page_size, **request_options):
|
| 205 |
+
model = type(model).to_dict(model)
|
| 206 |
+
yield model_types.Model(**model)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def list_tuned_models(
|
| 210 |
+
*,
|
| 211 |
+
page_size: int | None = 50,
|
| 212 |
+
client: glm.ModelServiceClient | None = None,
|
| 213 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 214 |
+
) -> model_types.TunedModelsIterable:
|
| 215 |
+
"""Calls the API to list all tuned models.
|
| 216 |
+
|
| 217 |
+
```
|
| 218 |
+
import pprint
|
| 219 |
+
for model in genai.list_tuned_models():
|
| 220 |
+
pprint.pprint(model)
|
| 221 |
+
```
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
page_size: How many `types.Models` to fetch per page (api call).
|
| 225 |
+
client: You may pass a `glm.ModelServiceClient` instead of using the default client.
|
| 226 |
+
request_options: Options for the request.
|
| 227 |
+
|
| 228 |
+
Yields:
|
| 229 |
+
`types.TunedModel` objects.
|
| 230 |
+
"""
|
| 231 |
+
if request_options is None:
|
| 232 |
+
request_options = {}
|
| 233 |
+
|
| 234 |
+
if client is None:
|
| 235 |
+
client = get_default_model_client()
|
| 236 |
+
|
| 237 |
+
for model in client.list_tuned_models(
|
| 238 |
+
page_size=page_size,
|
| 239 |
+
**request_options,
|
| 240 |
+
):
|
| 241 |
+
model = type(model).to_dict(model)
|
| 242 |
+
yield model_types.decode_tuned_model(model)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def create_tuned_model(
|
| 246 |
+
source_model: model_types.AnyModelNameOptions,
|
| 247 |
+
training_data: model_types.TuningDataOptions,
|
| 248 |
+
*,
|
| 249 |
+
id: str | None = None,
|
| 250 |
+
display_name: str | None = None,
|
| 251 |
+
description: str | None = None,
|
| 252 |
+
temperature: float | None = None,
|
| 253 |
+
top_p: float | None = None,
|
| 254 |
+
top_k: int | None = None,
|
| 255 |
+
epoch_count: int | None = None,
|
| 256 |
+
batch_size: int | None = None,
|
| 257 |
+
learning_rate: float | None = None,
|
| 258 |
+
input_key: str = "text_input",
|
| 259 |
+
output_key: str = "output",
|
| 260 |
+
client: glm.ModelServiceClient | None = None,
|
| 261 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 262 |
+
) -> operations.CreateTunedModelOperation:
|
| 263 |
+
"""Calls the API to initiate a tuning process that optimizes a model for specific data, returning an operation object to track and manage the tuning progress.
|
| 264 |
+
|
| 265 |
+
Since tuning a model can take significant time, this API doesn't wait for the tuning to complete.
|
| 266 |
+
Instead, it returns a `google.api_core.operation.Operation` object that lets you check on the
|
| 267 |
+
status of the tuning job, or wait for it to complete, and check the result.
|
| 268 |
+
|
| 269 |
+
After the job completes you can either find the resulting `TunedModel` object in
|
| 270 |
+
`Operation.result()` or `palm.list_tuned_models` or `palm.get_tuned_model(model_id)`.
|
| 271 |
+
|
| 272 |
+
```
|
| 273 |
+
my_id = "my-tuned-model-id"
|
| 274 |
+
operation = palm.create_tuned_model(
|
| 275 |
+
id = my_id,
|
| 276 |
+
source_model="models/text-bison-001",
|
| 277 |
+
training_data=[{'text_input': 'example input', 'output': 'example output'},...]
|
| 278 |
+
)
|
| 279 |
+
tuned_model=operation.result() # Wait for tuning to finish
|
| 280 |
+
|
| 281 |
+
palm.generate_text(f"tunedModels/{my_id}", prompt="...")
|
| 282 |
+
```
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
source_model: The name of the model to tune.
|
| 286 |
+
training_data: The dataset to tune the model on. This must be either:
|
| 287 |
+
* A `protos.Dataset`, or
|
| 288 |
+
* An `Iterable` of:
|
| 289 |
+
*`protos.TuningExample`,
|
| 290 |
+
* `{'text_input': text_input, 'output': output}` dicts
|
| 291 |
+
* `(text_input, output)` tuples.
|
| 292 |
+
* A `Mapping` of `Iterable[str]` - use `input_key` and `output_key` to choose which
|
| 293 |
+
columns to use as the input/output
|
| 294 |
+
* A csv file (will be read with `pd.read_csv` and handles as a `Mapping`
|
| 295 |
+
above). This can be:
|
| 296 |
+
* A local path as a `str` or `pathlib.Path`.
|
| 297 |
+
* A url for a csv file.
|
| 298 |
+
* The url of a Google Sheets file.
|
| 299 |
+
* A JSON file - Its contents will be handled either as an `Iterable` or `Mapping`
|
| 300 |
+
above. This can be:
|
| 301 |
+
* A local path as a `str` or `pathlib.Path`.
|
| 302 |
+
id: The model identifier, used to refer to the model in the API
|
| 303 |
+
`tunedModels/{id}`. Must be unique.
|
| 304 |
+
display_name: A human-readable name for display.
|
| 305 |
+
description: A description of the tuned model.
|
| 306 |
+
temperature: The default temperature for the tuned model, see `types.Model` for details.
|
| 307 |
+
top_p: The default `top_p` for the model, see `types.Model` for details.
|
| 308 |
+
top_k: The default `top_k` for the model, see `types.Model` for details.
|
| 309 |
+
epoch_count: The number of tuning epochs to run. An epoch is a pass over the whole dataset.
|
| 310 |
+
batch_size: The number of examples to use in each training batch.
|
| 311 |
+
learning_rate: The step size multiplier for the gradient updates.
|
| 312 |
+
client: Which client to use.
|
| 313 |
+
request_options: Options for the request.
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
A [`google.api_core.operation.Operation`](https://googleapis.dev/python/google-api-core/latest/operation.html)
|
| 317 |
+
"""
|
| 318 |
+
if request_options is None:
|
| 319 |
+
request_options = {}
|
| 320 |
+
|
| 321 |
+
if client is None:
|
| 322 |
+
client = get_default_model_client()
|
| 323 |
+
|
| 324 |
+
source_model_name = model_types.make_model_name(source_model)
|
| 325 |
+
base_model_name = get_base_model_name(source_model)
|
| 326 |
+
if source_model_name.startswith("models/"):
|
| 327 |
+
source_model = {"base_model": source_model_name}
|
| 328 |
+
elif source_model_name.startswith("tunedModels/"):
|
| 329 |
+
source_model = {
|
| 330 |
+
"tuned_model_source": {
|
| 331 |
+
"tuned_model": source_model_name,
|
| 332 |
+
"base_model": base_model_name,
|
| 333 |
+
}
|
| 334 |
+
}
|
| 335 |
+
else:
|
| 336 |
+
raise ValueError(
|
| 337 |
+
f"Invalid model name: The provided model '{source_model}' does not match any known model patterns such as 'models/' or 'tunedModels/'"
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
training_data = model_types.encode_tuning_data(
|
| 341 |
+
training_data, input_key=input_key, output_key=output_key
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
hyperparameters = protos.Hyperparameters(
|
| 345 |
+
epoch_count=epoch_count,
|
| 346 |
+
batch_size=batch_size,
|
| 347 |
+
learning_rate=learning_rate,
|
| 348 |
+
)
|
| 349 |
+
tuning_task = protos.TuningTask(
|
| 350 |
+
training_data=training_data,
|
| 351 |
+
hyperparameters=hyperparameters,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
tuned_model = protos.TunedModel(
|
| 355 |
+
**source_model,
|
| 356 |
+
display_name=display_name,
|
| 357 |
+
description=description,
|
| 358 |
+
temperature=temperature,
|
| 359 |
+
top_p=top_p,
|
| 360 |
+
top_k=top_k,
|
| 361 |
+
tuning_task=tuning_task,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
operation = client.create_tuned_model(
|
| 365 |
+
dict(tuned_model_id=id, tuned_model=tuned_model), **request_options
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
return operations.CreateTunedModelOperation.from_core_operation(operation)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@typing.overload
|
| 372 |
+
def update_tuned_model(
|
| 373 |
+
tuned_model: protos.TunedModel,
|
| 374 |
+
updates: None = None,
|
| 375 |
+
*,
|
| 376 |
+
client: glm.ModelServiceClient | None = None,
|
| 377 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 378 |
+
) -> model_types.TunedModel:
|
| 379 |
+
pass
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
@typing.overload
|
| 383 |
+
def update_tuned_model(
|
| 384 |
+
tuned_model: str,
|
| 385 |
+
updates: dict[str, Any],
|
| 386 |
+
*,
|
| 387 |
+
client: glm.ModelServiceClient | None = None,
|
| 388 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 389 |
+
) -> model_types.TunedModel:
|
| 390 |
+
pass
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def update_tuned_model(
|
| 394 |
+
tuned_model: str | protos.TunedModel,
|
| 395 |
+
updates: dict[str, Any] | None = None,
|
| 396 |
+
*,
|
| 397 |
+
client: glm.ModelServiceClient | None = None,
|
| 398 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 399 |
+
) -> model_types.TunedModel:
|
| 400 |
+
"""Calls the API to push updates to a specified tuned model where only certain attributes are updatable."""
|
| 401 |
+
|
| 402 |
+
if request_options is None:
|
| 403 |
+
request_options = {}
|
| 404 |
+
|
| 405 |
+
if client is None:
|
| 406 |
+
client = get_default_model_client()
|
| 407 |
+
|
| 408 |
+
if isinstance(tuned_model, str):
|
| 409 |
+
name = tuned_model
|
| 410 |
+
if not isinstance(updates, dict):
|
| 411 |
+
raise TypeError(
|
| 412 |
+
f"Invalid argument type: In the function `update_tuned_model(name:str, updates: dict)`, the `updates` argument must be of type `dict`. Received type: {type(updates).__name__}."
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
tuned_model = client.get_tuned_model(name=name, **request_options)
|
| 416 |
+
|
| 417 |
+
updates = flatten_update_paths(updates)
|
| 418 |
+
field_mask = field_mask_pb2.FieldMask()
|
| 419 |
+
for path in updates.keys():
|
| 420 |
+
field_mask.paths.append(path)
|
| 421 |
+
for path, value in updates.items():
|
| 422 |
+
_apply_update(tuned_model, path, value)
|
| 423 |
+
elif isinstance(tuned_model, protos.TunedModel):
|
| 424 |
+
if updates is not None:
|
| 425 |
+
raise ValueError(
|
| 426 |
+
"Invalid argument: When calling `update_tuned_model(tuned_model:protos.TunedModel, updates=None)`, "
|
| 427 |
+
"the `updates` argument must not be set."
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
name = tuned_model.name
|
| 431 |
+
was = client.get_tuned_model(name=name)
|
| 432 |
+
field_mask = protobuf_helpers.field_mask(was._pb, tuned_model._pb)
|
| 433 |
+
else:
|
| 434 |
+
raise TypeError(
|
| 435 |
+
"Invalid argument type: In the function `update_tuned_model(tuned_model:dict|protos.TunedModel)`, the "
|
| 436 |
+
f"`tuned_model` argument must be of type `dict` or `protos.TunedModel`. Received type: {type(tuned_model).__name__}."
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
result = client.update_tuned_model(
|
| 440 |
+
protos.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask),
|
| 441 |
+
**request_options,
|
| 442 |
+
)
|
| 443 |
+
return model_types.decode_tuned_model(result)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def _apply_update(thing, path, value):
|
| 447 |
+
parts = path.split(".")
|
| 448 |
+
for part in parts[:-1]:
|
| 449 |
+
thing = getattr(thing, part)
|
| 450 |
+
setattr(thing, parts[-1], value)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def delete_tuned_model(
|
| 454 |
+
tuned_model: model_types.TunedModelNameOptions,
|
| 455 |
+
client: glm.ModelServiceClient | None = None,
|
| 456 |
+
request_options: helper_types.RequestOptionsType | None = None,
|
| 457 |
+
) -> None:
|
| 458 |
+
"""Calls the API to delete a specified tuned model"""
|
| 459 |
+
|
| 460 |
+
if request_options is None:
|
| 461 |
+
request_options = {}
|
| 462 |
+
|
| 463 |
+
if client is None:
|
| 464 |
+
client = get_default_model_client()
|
| 465 |
+
|
| 466 |
+
name = model_types.make_model_name(tuned_model)
|
| 467 |
+
client.delete_tuned_model(name=name, **request_options)
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Notebook extensions for Generative AI."""
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_ipython_extension(ipython):
|
| 19 |
+
"""Register the Colab Magic extension to support %load_ext."""
|
| 20 |
+
# pylint: disable-next=g-import-not-at-top
|
| 21 |
+
from google.generativeai.notebook import magics
|
| 22 |
+
|
| 23 |
+
ipython.register_magics(magics.Magics)
|
| 24 |
+
|
| 25 |
+
# Since we're in an interactive environment, make the tables prettier.
|
| 26 |
+
try:
|
| 27 |
+
# pylint: disable-next=g-import-not-at-top
|
| 28 |
+
from google import colab # type: ignore
|
| 29 |
+
|
| 30 |
+
colab.data_table.enable_dataframe_formatter()
|
| 31 |
+
except ImportError:
|
| 32 |
+
pass
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/command.cpython-311.pyc
ADDED
|
Binary file (1.77 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/compare_cmd.cpython-311.pyc
ADDED
|
Binary file (2.84 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/eval_cmd.cpython-311.pyc
ADDED
|
Binary file (3.01 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/flag_def.cpython-311.pyc
ADDED
|
Binary file (21.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/input_utils.cpython-311.pyc
ADDED
|
Binary file (3.62 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/magics_engine.cpython-311.pyc
ADDED
|
Binary file (5.64 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/parsed_args_lib.cpython-311.pyc
ADDED
|
Binary file (3.21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/run_cmd.cpython-311.pyc
ADDED
|
Binary file (3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_id.cpython-311.pyc
ADDED
|
Binary file (4.92 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_utils.cpython-311.pyc
ADDED
|
Binary file (6.03 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/cmd_line_parser.py
ADDED
|
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Parses an LLM command line."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import shlex
|
| 20 |
+
import sys
|
| 21 |
+
from typing import AbstractSet, Any, Callable, MutableMapping, Sequence
|
| 22 |
+
|
| 23 |
+
from google.generativeai.notebook import argument_parser
|
| 24 |
+
from google.generativeai.notebook import flag_def
|
| 25 |
+
from google.generativeai.notebook import input_utils
|
| 26 |
+
from google.generativeai.notebook import model_registry
|
| 27 |
+
from google.generativeai.notebook import output_utils
|
| 28 |
+
from google.generativeai.notebook import parsed_args_lib
|
| 29 |
+
from google.generativeai.notebook import post_process_utils
|
| 30 |
+
from google.generativeai.notebook import py_utils
|
| 31 |
+
from google.generativeai.notebook import sheets_utils
|
| 32 |
+
from google.generativeai.notebook.lib import llm_function
|
| 33 |
+
from google.generativeai.notebook.lib import llmfn_inputs_source
|
| 34 |
+
from google.generativeai.notebook.lib import llmfn_outputs
|
| 35 |
+
from google.generativeai.notebook.lib import model as model_lib
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
_MIN_CANDIDATE_COUNT = 1
|
| 39 |
+
_MAX_CANDIDATE_COUNT = 8
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _validate_input_source_against_placeholders(
|
| 43 |
+
source: llmfn_inputs_source.LLMFnInputsSource,
|
| 44 |
+
placeholders: AbstractSet[str],
|
| 45 |
+
) -> None:
|
| 46 |
+
for inputs in source.to_normalized_inputs():
|
| 47 |
+
for keyword in placeholders:
|
| 48 |
+
if keyword not in inputs:
|
| 49 |
+
raise ValueError('Placeholder "{}" not found in input'.format(keyword))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _get_resolve_input_from_py_var_fn(
|
| 53 |
+
placeholders: AbstractSet[str] | None,
|
| 54 |
+
) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]:
|
| 55 |
+
def _fn(var_name: str) -> llmfn_inputs_source.LLMFnInputsSource:
|
| 56 |
+
source = input_utils.get_inputs_source_from_py_var(var_name)
|
| 57 |
+
if placeholders:
|
| 58 |
+
_validate_input_source_against_placeholders(source, placeholders)
|
| 59 |
+
return source
|
| 60 |
+
|
| 61 |
+
return _fn
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _resolve_compare_fn_var(
|
| 65 |
+
name: str,
|
| 66 |
+
) -> tuple[str, parsed_args_lib.TextResultCompareFn]:
|
| 67 |
+
"""Resolves a value passed into --compare_fn."""
|
| 68 |
+
fn = py_utils.get_py_var(name)
|
| 69 |
+
if not isinstance(fn, Callable):
|
| 70 |
+
raise ValueError('Variable "{}" does not contain a Callable object'.format(name))
|
| 71 |
+
|
| 72 |
+
return name, fn
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _resolve_ground_truth_var(name: str) -> Sequence[str]:
|
| 76 |
+
"""Resolves a value passed into --ground_truth."""
|
| 77 |
+
value = py_utils.get_py_var(name)
|
| 78 |
+
|
| 79 |
+
# "str" and "bytes" are also Sequences but we want an actual Sequence of
|
| 80 |
+
# strings, like a list.
|
| 81 |
+
if not isinstance(value, Sequence) or isinstance(value, str) or isinstance(value, bytes):
|
| 82 |
+
raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name))
|
| 83 |
+
for x in value:
|
| 84 |
+
if not isinstance(x, str):
|
| 85 |
+
raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name))
|
| 86 |
+
return value
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _get_resolve_sheets_inputs_fn(
|
| 90 |
+
placeholders: AbstractSet[str] | None,
|
| 91 |
+
) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]:
|
| 92 |
+
def _fn(value: str) -> llmfn_inputs_source.LLMFnInputsSource:
|
| 93 |
+
sheets_id = sheets_utils.get_sheets_id_from_str(value)
|
| 94 |
+
source = sheets_utils.SheetsInputs(sheets_id)
|
| 95 |
+
if placeholders:
|
| 96 |
+
_validate_input_source_against_placeholders(source, placeholders)
|
| 97 |
+
return source
|
| 98 |
+
|
| 99 |
+
return _fn
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _resolve_sheets_outputs(value: str) -> llmfn_outputs.LLMFnOutputsSink:
|
| 103 |
+
sheets_id = sheets_utils.get_sheets_id_from_str(value)
|
| 104 |
+
return sheets_utils.SheetsOutputs(sheets_id)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _add_model_flags(
|
| 108 |
+
parser: argparse.ArgumentParser,
|
| 109 |
+
) -> None:
|
| 110 |
+
"""Adds flags that are related to model selection and config."""
|
| 111 |
+
flag_def.EnumFlagDef(
|
| 112 |
+
name="model_type",
|
| 113 |
+
short_name="mt",
|
| 114 |
+
enum_type=model_registry.ModelName,
|
| 115 |
+
default_value=model_registry.ModelRegistry.DEFAULT_MODEL,
|
| 116 |
+
help_msg="The type of model to use.",
|
| 117 |
+
).add_argument_to_parser(parser)
|
| 118 |
+
|
| 119 |
+
def _check_is_greater_than_or_equal_to_zero(x: float) -> float:
|
| 120 |
+
if x < 0:
|
| 121 |
+
raise ValueError("Value should be greater than or equal to zero, got {}".format(x))
|
| 122 |
+
return x
|
| 123 |
+
|
| 124 |
+
flag_def.SingleValueFlagDef(
|
| 125 |
+
name="temperature",
|
| 126 |
+
short_name="t",
|
| 127 |
+
parse_type=float,
|
| 128 |
+
# Use None for default value to indicate that this will use the default
|
| 129 |
+
# value in Text service.
|
| 130 |
+
default_value=None,
|
| 131 |
+
parse_to_dest_type_fn=_check_is_greater_than_or_equal_to_zero,
|
| 132 |
+
help_msg=(
|
| 133 |
+
"Controls the randomness of the output. Must be positive. Typical"
|
| 134 |
+
" values are in the range: [0.0, 1.0]. Higher values produce a more"
|
| 135 |
+
" random and varied response. A temperature of zero will be"
|
| 136 |
+
" deterministic."
|
| 137 |
+
),
|
| 138 |
+
).add_argument_to_parser(parser)
|
| 139 |
+
|
| 140 |
+
flag_def.SingleValueFlagDef(
|
| 141 |
+
name="model",
|
| 142 |
+
short_name="m",
|
| 143 |
+
default_value=None,
|
| 144 |
+
help_msg=(
|
| 145 |
+
"The name of the model to use. If not provided, a default model will" " be used."
|
| 146 |
+
),
|
| 147 |
+
).add_argument_to_parser(parser)
|
| 148 |
+
|
| 149 |
+
def _check_candidate_count_range(x: Any) -> int:
|
| 150 |
+
if x < _MIN_CANDIDATE_COUNT or x > _MAX_CANDIDATE_COUNT:
|
| 151 |
+
raise ValueError(
|
| 152 |
+
"Value should be in the range [{}, {}], got {}".format(
|
| 153 |
+
_MIN_CANDIDATE_COUNT, _MAX_CANDIDATE_COUNT, x
|
| 154 |
+
)
|
| 155 |
+
)
|
| 156 |
+
return int(x)
|
| 157 |
+
|
| 158 |
+
flag_def.SingleValueFlagDef(
|
| 159 |
+
name="candidate_count",
|
| 160 |
+
short_name="cc",
|
| 161 |
+
parse_type=int,
|
| 162 |
+
# Use None for default value to indicate that this will use the default
|
| 163 |
+
# value in Text service.
|
| 164 |
+
default_value=None,
|
| 165 |
+
parse_to_dest_type_fn=_check_candidate_count_range,
|
| 166 |
+
help_msg="The number of candidates to produce.",
|
| 167 |
+
).add_argument_to_parser(parser)
|
| 168 |
+
|
| 169 |
+
flag_def.BooleanFlagDef(
|
| 170 |
+
name="unique",
|
| 171 |
+
help_msg="Whether to dedupe candidates returned by the model.",
|
| 172 |
+
).add_argument_to_parser(parser)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _add_input_flags(
|
| 176 |
+
parser: argparse.ArgumentParser,
|
| 177 |
+
placeholders: AbstractSet[str] | None,
|
| 178 |
+
) -> None:
|
| 179 |
+
"""Adds flags to read inputs from a Python variable or Sheets."""
|
| 180 |
+
flag_def.MultiValuesFlagDef(
|
| 181 |
+
name="inputs",
|
| 182 |
+
short_name="i",
|
| 183 |
+
dest_type=llmfn_inputs_source.LLMFnInputsSource,
|
| 184 |
+
parse_to_dest_type_fn=_get_resolve_input_from_py_var_fn(placeholders),
|
| 185 |
+
help_msg=(
|
| 186 |
+
"Optional names of Python variables containing inputs to use to"
|
| 187 |
+
" instantiate a prompt. The variable must be either: a dictionary"
|
| 188 |
+
" {'key1': ['val1', 'val2'] ...}, or an instance of LLMFnInputsSource"
|
| 189 |
+
" such as SheetsInput."
|
| 190 |
+
),
|
| 191 |
+
).add_argument_to_parser(parser)
|
| 192 |
+
|
| 193 |
+
flag_def.MultiValuesFlagDef(
|
| 194 |
+
name="sheets_input_names",
|
| 195 |
+
short_name="si",
|
| 196 |
+
dest_type=llmfn_inputs_source.LLMFnInputsSource,
|
| 197 |
+
parse_to_dest_type_fn=_get_resolve_sheets_inputs_fn(placeholders),
|
| 198 |
+
help_msg=(
|
| 199 |
+
"Optional names of Google Sheets to read inputs from. This is"
|
| 200 |
+
" equivalent to using --inputs with the names of variables that are"
|
| 201 |
+
" instances of SheetsInputs, just more convenient to use."
|
| 202 |
+
),
|
| 203 |
+
).add_argument_to_parser(parser)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def _add_output_flags(
|
| 207 |
+
parser: argparse.ArgumentParser,
|
| 208 |
+
) -> None:
|
| 209 |
+
"""Adds flags to write outputs to a Python variable."""
|
| 210 |
+
flag_def.MultiValuesFlagDef(
|
| 211 |
+
name="outputs",
|
| 212 |
+
short_name="o",
|
| 213 |
+
dest_type=llmfn_outputs.LLMFnOutputsSink,
|
| 214 |
+
parse_to_dest_type_fn=output_utils.get_outputs_sink_from_py_var,
|
| 215 |
+
help_msg=(
|
| 216 |
+
"Optional names of Python variables to output to. If the Python"
|
| 217 |
+
" variable has not already been defined, it will be created. If the"
|
| 218 |
+
" variable is defined and is an instance of LLMFnOutputsSink, the"
|
| 219 |
+
" outputs will be written through the sink's write_outputs() method."
|
| 220 |
+
),
|
| 221 |
+
).add_argument_to_parser(parser)
|
| 222 |
+
|
| 223 |
+
flag_def.MultiValuesFlagDef(
|
| 224 |
+
name="sheets_output_names",
|
| 225 |
+
short_name="so",
|
| 226 |
+
dest_type=llmfn_outputs.LLMFnOutputsSink,
|
| 227 |
+
parse_to_dest_type_fn=_resolve_sheets_outputs,
|
| 228 |
+
help_msg=(
|
| 229 |
+
"Optional names of Google Sheets to write inputs to. This is"
|
| 230 |
+
" equivalent to using --outputs with the names of variables that are"
|
| 231 |
+
" instances of SheetsOutputs, just more convenient to use."
|
| 232 |
+
),
|
| 233 |
+
).add_argument_to_parser(parser)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _add_compare_flags(
|
| 237 |
+
parser: argparse.ArgumentParser,
|
| 238 |
+
) -> None:
|
| 239 |
+
flag_def.MultiValuesFlagDef(
|
| 240 |
+
name="compare_fn",
|
| 241 |
+
dest_type=tuple,
|
| 242 |
+
parse_to_dest_type_fn=_resolve_compare_fn_var,
|
| 243 |
+
help_msg=(
|
| 244 |
+
"An optional function that takes two inputs: (lhs_result, rhs_result)"
|
| 245 |
+
" which are the results of the left- and right-hand side functions. "
|
| 246 |
+
"Multiple comparison functions can be provided."
|
| 247 |
+
),
|
| 248 |
+
).add_argument_to_parser(parser)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _add_eval_flags(
|
| 252 |
+
parser: argparse.ArgumentParser,
|
| 253 |
+
) -> None:
|
| 254 |
+
flag_def.SingleValueFlagDef(
|
| 255 |
+
name="ground_truth",
|
| 256 |
+
required=True,
|
| 257 |
+
dest_type=Sequence,
|
| 258 |
+
parse_to_dest_type_fn=_resolve_ground_truth_var,
|
| 259 |
+
help_msg=(
|
| 260 |
+
"A variable containing a Sequence of strings representing the ground"
|
| 261 |
+
" truth that the output of this cell will be compared against. It"
|
| 262 |
+
" should have the same number of entries as inputs."
|
| 263 |
+
),
|
| 264 |
+
).add_argument_to_parser(parser)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def _create_run_parser(
|
| 268 |
+
parser: argparse.ArgumentParser,
|
| 269 |
+
placeholders: AbstractSet[str] | None,
|
| 270 |
+
) -> None:
|
| 271 |
+
"""Adds flags for the `run` command.
|
| 272 |
+
|
| 273 |
+
`run` sends one or more prompts to a model.
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
parser: The parser to which flags will be added.
|
| 277 |
+
placeholders: Placeholders from prompts in the cell contents.
|
| 278 |
+
"""
|
| 279 |
+
_add_model_flags(parser)
|
| 280 |
+
_add_input_flags(parser, placeholders)
|
| 281 |
+
_add_output_flags(parser)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _create_compile_parser(
|
| 285 |
+
parser: argparse.ArgumentParser,
|
| 286 |
+
) -> None:
|
| 287 |
+
"""Adds flags for the compile command.
|
| 288 |
+
|
| 289 |
+
`compile` "compiles" a prompt and model call into a callable function.
|
| 290 |
+
|
| 291 |
+
Args:
|
| 292 |
+
parser: The parser to which flags will be added.
|
| 293 |
+
"""
|
| 294 |
+
|
| 295 |
+
# Add a positional argument for "compile_save_name".
|
| 296 |
+
def _compile_save_name_fn(var_name: str) -> str:
|
| 297 |
+
try:
|
| 298 |
+
py_utils.validate_var_name(var_name)
|
| 299 |
+
except ValueError as e:
|
| 300 |
+
# Re-raise as ArgumentError to preserve the original error message.
|
| 301 |
+
raise argparse.ArgumentError(None, "{}".format(e)) from e
|
| 302 |
+
return var_name
|
| 303 |
+
|
| 304 |
+
save_name_help = "The name of a Python variable to save the compiled function to."
|
| 305 |
+
parser.add_argument("compile_save_name", help=save_name_help, type=_compile_save_name_fn)
|
| 306 |
+
_add_model_flags(parser)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def _create_compare_parser(
|
| 310 |
+
parser: argparse.ArgumentParser,
|
| 311 |
+
placeholders: AbstractSet[str] | None,
|
| 312 |
+
) -> None:
|
| 313 |
+
"""Adds flags for the compare command.
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
parser: The parser to which flags will be added.
|
| 317 |
+
placeholders: Placeholders from prompts in the compiled functions.
|
| 318 |
+
"""
|
| 319 |
+
|
| 320 |
+
# Add positional arguments.
|
| 321 |
+
def _resolve_llm_function_fn(
|
| 322 |
+
var_name: str,
|
| 323 |
+
) -> tuple[str, llm_function.LLMFunction]:
|
| 324 |
+
try:
|
| 325 |
+
py_utils.validate_var_name(var_name)
|
| 326 |
+
except ValueError as e:
|
| 327 |
+
# Re-raise as ArgumentError to preserve the original error message.
|
| 328 |
+
raise argparse.ArgumentError(None, "{}".format(e)) from e
|
| 329 |
+
|
| 330 |
+
fn = py_utils.get_py_var(var_name)
|
| 331 |
+
if not isinstance(fn, llm_function.LLMFunction):
|
| 332 |
+
raise argparse.ArgumentError(
|
| 333 |
+
None,
|
| 334 |
+
'{} is not a function created with the "compile" command'.format(var_name),
|
| 335 |
+
)
|
| 336 |
+
return var_name, fn
|
| 337 |
+
|
| 338 |
+
name_help = (
|
| 339 |
+
"The name of a Python variable containing a function previously created"
|
| 340 |
+
' with the "compile" command.'
|
| 341 |
+
)
|
| 342 |
+
parser.add_argument("lhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn)
|
| 343 |
+
parser.add_argument("rhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn)
|
| 344 |
+
|
| 345 |
+
_add_input_flags(parser, placeholders)
|
| 346 |
+
_add_output_flags(parser)
|
| 347 |
+
_add_compare_flags(parser)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def _create_eval_parser(
|
| 351 |
+
parser: argparse.ArgumentParser,
|
| 352 |
+
placeholders: AbstractSet[str] | None,
|
| 353 |
+
) -> None:
|
| 354 |
+
"""Adds flags for the eval command.
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
parser: The parser to which flags will be added.
|
| 358 |
+
placeholders: Placeholders from prompts in the cell contents.
|
| 359 |
+
"""
|
| 360 |
+
_add_model_flags(parser)
|
| 361 |
+
_add_input_flags(parser, placeholders)
|
| 362 |
+
_add_output_flags(parser)
|
| 363 |
+
_add_compare_flags(parser)
|
| 364 |
+
_add_eval_flags(parser)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def _create_parser(
|
| 368 |
+
placeholders: AbstractSet[str] | None,
|
| 369 |
+
) -> argparse.ArgumentParser:
|
| 370 |
+
"""Create the full parser."""
|
| 371 |
+
system_name = "llm"
|
| 372 |
+
description = "A system for interacting with LLMs."
|
| 373 |
+
epilog = ""
|
| 374 |
+
|
| 375 |
+
# Commands
|
| 376 |
+
parser = argument_parser.ArgumentParser(
|
| 377 |
+
prog=system_name,
|
| 378 |
+
description=description,
|
| 379 |
+
epilog=epilog,
|
| 380 |
+
)
|
| 381 |
+
subparsers = parser.add_subparsers(dest="cmd")
|
| 382 |
+
_create_run_parser(
|
| 383 |
+
subparsers.add_parser(parsed_args_lib.CommandName.RUN_CMD.value),
|
| 384 |
+
placeholders,
|
| 385 |
+
)
|
| 386 |
+
_create_compile_parser(subparsers.add_parser(parsed_args_lib.CommandName.COMPILE_CMD.value))
|
| 387 |
+
_create_compare_parser(
|
| 388 |
+
subparsers.add_parser(parsed_args_lib.CommandName.COMPARE_CMD.value),
|
| 389 |
+
placeholders,
|
| 390 |
+
)
|
| 391 |
+
_create_eval_parser(
|
| 392 |
+
subparsers.add_parser(parsed_args_lib.CommandName.EVAL_CMD.value),
|
| 393 |
+
placeholders,
|
| 394 |
+
)
|
| 395 |
+
return parser
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def _validate_parsed_args(parsed_args: parsed_args_lib.ParsedArgs) -> None:
|
| 399 |
+
# If candidate_count is not set (i.e. is None), assuming the default value
|
| 400 |
+
# is 1.
|
| 401 |
+
if parsed_args.unique and (
|
| 402 |
+
parsed_args.model_args.candidate_count is None
|
| 403 |
+
or parsed_args.model_args.candidate_count == 1
|
| 404 |
+
):
|
| 405 |
+
print(
|
| 406 |
+
'"--unique" works across candidates only: it should be used with'
|
| 407 |
+
" --candidate_count set to a value greater-than one."
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
class CmdLineParser:
|
| 412 |
+
"""Implementation of Magics command line parser."""
|
| 413 |
+
|
| 414 |
+
# Commands
|
| 415 |
+
DEFAULT_CMD = parsed_args_lib.CommandName.RUN_CMD
|
| 416 |
+
|
| 417 |
+
# Post-processing operator.
|
| 418 |
+
PIPE_OP = "|"
|
| 419 |
+
|
| 420 |
+
@classmethod
|
| 421 |
+
def _split_post_processing_tokens(
|
| 422 |
+
cls,
|
| 423 |
+
tokens: Sequence[str],
|
| 424 |
+
) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]:
|
| 425 |
+
"""Splits inputs into the command and post processing tokens.
|
| 426 |
+
|
| 427 |
+
The command is represented as a sequence of tokens.
|
| 428 |
+
See comments on the PostProcessingTokens type alias.
|
| 429 |
+
|
| 430 |
+
E.g. Given: "run --temperature 0.5 | add_score | to_lower_case"
|
| 431 |
+
The command will be: ["run", "--temperature", "0.5"].
|
| 432 |
+
The post processing tokens will be: [["add_score"], ["to_lower_case"]]
|
| 433 |
+
|
| 434 |
+
Args:
|
| 435 |
+
tokens: The command line tokens.
|
| 436 |
+
|
| 437 |
+
Returns:
|
| 438 |
+
A tuple of (command line, post processing tokens).
|
| 439 |
+
"""
|
| 440 |
+
split_tokens = []
|
| 441 |
+
start_idx: int | None = None
|
| 442 |
+
for token_num, token in enumerate(tokens):
|
| 443 |
+
if start_idx is None:
|
| 444 |
+
start_idx = token_num
|
| 445 |
+
if token == CmdLineParser.PIPE_OP:
|
| 446 |
+
split_tokens.append(tokens[start_idx:token_num] if start_idx is not None else [])
|
| 447 |
+
start_idx = None
|
| 448 |
+
|
| 449 |
+
# Add the remaining tokens after the last PIPE_OP.
|
| 450 |
+
split_tokens.append(tokens[start_idx:] if start_idx is not None else [])
|
| 451 |
+
|
| 452 |
+
return split_tokens[0], split_tokens[1:]
|
| 453 |
+
|
| 454 |
+
@classmethod
|
| 455 |
+
def _tokenize_line(
|
| 456 |
+
cls, line: str
|
| 457 |
+
) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]:
|
| 458 |
+
"""Parses `line` and returns command line and post processing tokens."""
|
| 459 |
+
# Check to make sure there is a command at the start. If not, add the
|
| 460 |
+
# default command to the list of tokens.
|
| 461 |
+
tokens = shlex.split(line)
|
| 462 |
+
if not tokens:
|
| 463 |
+
tokens = [CmdLineParser.DEFAULT_CMD.value]
|
| 464 |
+
first_token = tokens[0]
|
| 465 |
+
# Add default command if the first token is not the help token.
|
| 466 |
+
if not first_token[0].isalpha() and first_token not in ["-h", "--help"]:
|
| 467 |
+
tokens = [CmdLineParser.DEFAULT_CMD.value] + tokens
|
| 468 |
+
# Split line into tokens and post-processing
|
| 469 |
+
return CmdLineParser._split_post_processing_tokens(tokens)
|
| 470 |
+
|
| 471 |
+
@classmethod
|
| 472 |
+
def _get_model_args(
|
| 473 |
+
cls, parsed_results: MutableMapping[str, Any]
|
| 474 |
+
) -> tuple[MutableMapping[str, Any], model_lib.ModelArguments]:
|
| 475 |
+
"""Extracts fields for model args from `parsed_results`.
|
| 476 |
+
|
| 477 |
+
Keys specific to model arguments will be removed from `parsed_results`.
|
| 478 |
+
|
| 479 |
+
Args:
|
| 480 |
+
parsed_results: A dictionary of parsed arguments (from ArgumentParser). It
|
| 481 |
+
will be modified in place.
|
| 482 |
+
|
| 483 |
+
Returns:
|
| 484 |
+
A tuple of (updated parsed_results, model arguments).
|
| 485 |
+
"""
|
| 486 |
+
model = parsed_results.pop("model", None)
|
| 487 |
+
temperature = parsed_results.pop("temperature", None)
|
| 488 |
+
candidate_count = parsed_results.pop("candidate_count", None)
|
| 489 |
+
|
| 490 |
+
model_args = model_lib.ModelArguments(
|
| 491 |
+
model=model,
|
| 492 |
+
temperature=temperature,
|
| 493 |
+
candidate_count=candidate_count,
|
| 494 |
+
)
|
| 495 |
+
return parsed_results, model_args
|
| 496 |
+
|
| 497 |
+
def parse_line(
|
| 498 |
+
self,
|
| 499 |
+
line: str,
|
| 500 |
+
placeholders: AbstractSet[str] | None = None,
|
| 501 |
+
) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]:
|
| 502 |
+
"""Parses the commandline and returns ParsedArgs and post-processing tokens.
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
line: The line to parse (usually contents from cell Magics).
|
| 506 |
+
placeholders: Placeholders from prompts in the cell contents.
|
| 507 |
+
|
| 508 |
+
Returns:
|
| 509 |
+
A tuple of (parsed_args, post_processing_tokens).
|
| 510 |
+
"""
|
| 511 |
+
tokens, post_processing_tokens = CmdLineParser._tokenize_line(line)
|
| 512 |
+
|
| 513 |
+
parsed_args = self._get_parsed_args_from_cmd_line_tokens(
|
| 514 |
+
tokens=tokens, placeholders=placeholders
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
# Special-case for "compare" command: because the prompts are compiled into
|
| 518 |
+
# the left- and right-hand side functions rather than in the cell body, we
|
| 519 |
+
# cannot examine the cell body to get the placeholders.
|
| 520 |
+
#
|
| 521 |
+
# Instead we parse the command line twice: once to get the left- and right-
|
| 522 |
+
# functions, then we query the functions for their placeholders, then
|
| 523 |
+
# parse the commandline again to validate the inputs.
|
| 524 |
+
if parsed_args.cmd == parsed_args_lib.CommandName.COMPARE_CMD:
|
| 525 |
+
assert parsed_args.lhs_name_and_fn is not None
|
| 526 |
+
assert parsed_args.rhs_name_and_fn is not None
|
| 527 |
+
_, lhs_fn = parsed_args.lhs_name_and_fn
|
| 528 |
+
_, rhs_fn = parsed_args.rhs_name_and_fn
|
| 529 |
+
parsed_args = self._get_parsed_args_from_cmd_line_tokens(
|
| 530 |
+
tokens=tokens,
|
| 531 |
+
placeholders=frozenset(lhs_fn.get_placeholders()).union(rhs_fn.get_placeholders()),
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
_validate_parsed_args(parsed_args)
|
| 535 |
+
|
| 536 |
+
for expr in post_processing_tokens:
|
| 537 |
+
post_process_utils.validate_one_post_processing_expression(expr)
|
| 538 |
+
|
| 539 |
+
return parsed_args, post_processing_tokens
|
| 540 |
+
|
| 541 |
+
def _get_parsed_args_from_cmd_line_tokens(
|
| 542 |
+
self,
|
| 543 |
+
tokens: Sequence[str],
|
| 544 |
+
placeholders: AbstractSet[str] | None,
|
| 545 |
+
) -> parsed_args_lib.ParsedArgs:
|
| 546 |
+
"""Returns ParsedArgs from a tokenized command line."""
|
| 547 |
+
# Create a new parser to avoid reusing the temporary argparse.Namespace
|
| 548 |
+
# object.
|
| 549 |
+
results = _create_parser(placeholders).parse_args(tokens)
|
| 550 |
+
|
| 551 |
+
results_dict = vars(results)
|
| 552 |
+
results_dict["cmd"] = parsed_args_lib.CommandName(results_dict["cmd"])
|
| 553 |
+
|
| 554 |
+
results_dict, model_args = CmdLineParser._get_model_args(results_dict)
|
| 555 |
+
results_dict["model_args"] = model_args
|
| 556 |
+
|
| 557 |
+
return parsed_args_lib.ParsedArgs(**results_dict)
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/command.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Command."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import abc
|
| 19 |
+
import collections
|
| 20 |
+
from typing import Sequence
|
| 21 |
+
|
| 22 |
+
from google.generativeai.notebook import parsed_args_lib
|
| 23 |
+
from google.generativeai.notebook import post_process_utils
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
ProcessingCommand = collections.namedtuple("ProcessingCommand", ["name", "fn"])
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Command(abc.ABC):
|
| 30 |
+
"""Base class for implementation of Magics commands like "run"."""
|
| 31 |
+
|
| 32 |
+
@abc.abstractmethod
|
| 33 |
+
def execute(
|
| 34 |
+
self,
|
| 35 |
+
parsed_args: parsed_args_lib.ParsedArgs,
|
| 36 |
+
cell_content: str,
|
| 37 |
+
post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
|
| 38 |
+
):
|
| 39 |
+
"""Executes the command given `parsed_args` and the `cell_content`."""
|
| 40 |
+
|
| 41 |
+
@abc.abstractmethod
|
| 42 |
+
def parse_post_processing_tokens(
|
| 43 |
+
self, tokens: Sequence[Sequence[str]]
|
| 44 |
+
) -> Sequence[post_process_utils.ParsedPostProcessExpr]:
|
| 45 |
+
"""Parses post-processing tokens for this command."""
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/compare_cmd.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""The compare command."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from typing import Sequence
|
| 19 |
+
|
| 20 |
+
from google.generativeai.notebook import command
|
| 21 |
+
from google.generativeai.notebook import command_utils
|
| 22 |
+
from google.generativeai.notebook import input_utils
|
| 23 |
+
from google.generativeai.notebook import ipython_env
|
| 24 |
+
from google.generativeai.notebook import output_utils
|
| 25 |
+
from google.generativeai.notebook import parsed_args_lib
|
| 26 |
+
from google.generativeai.notebook import post_process_utils
|
| 27 |
+
import pandas
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class CompareCommand(command.Command):
|
| 31 |
+
"""Implementation of "compare" command."""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
env: ipython_env.IPythonEnv | None = None,
|
| 36 |
+
):
|
| 37 |
+
"""Constructor.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
env: The IPythonEnv environment.
|
| 41 |
+
"""
|
| 42 |
+
super().__init__()
|
| 43 |
+
self._ipython_env = env
|
| 44 |
+
|
| 45 |
+
def execute(
|
| 46 |
+
self,
|
| 47 |
+
parsed_args: parsed_args_lib.ParsedArgs,
|
| 48 |
+
cell_content: str,
|
| 49 |
+
post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
|
| 50 |
+
) -> pandas.DataFrame:
|
| 51 |
+
# We expect CmdLineParser to have already read the inputs once to validate
|
| 52 |
+
# that the placeholders in the prompt are present in the inputs, so we can
|
| 53 |
+
# suppress the status messages here.
|
| 54 |
+
inputs = input_utils.join_inputs_sources(parsed_args, suppress_status_msgs=True)
|
| 55 |
+
|
| 56 |
+
llm_cmp_fn = command_utils.create_llm_compare_function(
|
| 57 |
+
env=self._ipython_env,
|
| 58 |
+
parsed_args=parsed_args,
|
| 59 |
+
post_processing_fns=post_processing_fns,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
results = llm_cmp_fn(inputs=inputs)
|
| 63 |
+
output_utils.write_to_outputs(results=results, parsed_args=parsed_args)
|
| 64 |
+
return results.as_pandas_dataframe()
|
| 65 |
+
|
| 66 |
+
def parse_post_processing_tokens(
|
| 67 |
+
self, tokens: Sequence[Sequence[str]]
|
| 68 |
+
) -> Sequence[post_process_utils.ParsedPostProcessExpr]:
|
| 69 |
+
if tokens:
|
| 70 |
+
raise RuntimeError('Post-processing is not supported by "compare"')
|
| 71 |
+
return []
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/compile_cmd.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""The compile command."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from typing import Sequence
|
| 19 |
+
from google.generativeai.notebook import command
|
| 20 |
+
from google.generativeai.notebook import command_utils
|
| 21 |
+
from google.generativeai.notebook import ipython_env
|
| 22 |
+
from google.generativeai.notebook import model_registry
|
| 23 |
+
from google.generativeai.notebook import parsed_args_lib
|
| 24 |
+
from google.generativeai.notebook import post_process_utils
|
| 25 |
+
from google.generativeai.notebook import py_utils
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class CompileCommand(command.Command):
|
| 29 |
+
"""Implementation of the "compile" command."""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
models: model_registry.ModelRegistry,
|
| 34 |
+
env: ipython_env.IPythonEnv | None = None,
|
| 35 |
+
):
|
| 36 |
+
"""Constructor.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
models: ModelRegistry instance.
|
| 40 |
+
env: The IPythonEnv environment.
|
| 41 |
+
"""
|
| 42 |
+
super().__init__()
|
| 43 |
+
self._models = models
|
| 44 |
+
self._ipython_env = env
|
| 45 |
+
|
| 46 |
+
def execute(
|
| 47 |
+
self,
|
| 48 |
+
parsed_args: parsed_args_lib.ParsedArgs,
|
| 49 |
+
cell_content: str,
|
| 50 |
+
post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
|
| 51 |
+
) -> str:
|
| 52 |
+
llm_fn = command_utils.create_llm_function(
|
| 53 |
+
models=self._models,
|
| 54 |
+
env=self._ipython_env,
|
| 55 |
+
parsed_args=parsed_args,
|
| 56 |
+
cell_content=cell_content,
|
| 57 |
+
post_processing_fns=post_processing_fns,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
py_utils.set_py_var(parsed_args.compile_save_name, llm_fn)
|
| 61 |
+
return "Saved function to Python variable: {}".format(parsed_args.compile_save_name)
|
| 62 |
+
|
| 63 |
+
def parse_post_processing_tokens(
|
| 64 |
+
self, tokens: Sequence[Sequence[str]]
|
| 65 |
+
) -> Sequence[post_process_utils.ParsedPostProcessExpr]:
|
| 66 |
+
return post_process_utils.resolve_post_processing_tokens(tokens)
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/eval_cmd.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""The eval command."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from typing import Sequence
|
| 19 |
+
|
| 20 |
+
from google.generativeai.notebook import command
|
| 21 |
+
from google.generativeai.notebook import command_utils
|
| 22 |
+
from google.generativeai.notebook import input_utils
|
| 23 |
+
from google.generativeai.notebook import ipython_env
|
| 24 |
+
from google.generativeai.notebook import model_registry
|
| 25 |
+
from google.generativeai.notebook import output_utils
|
| 26 |
+
from google.generativeai.notebook import parsed_args_lib
|
| 27 |
+
from google.generativeai.notebook import post_process_utils
|
| 28 |
+
import pandas
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class EvalCommand(command.Command):
|
| 32 |
+
"""Implementation of "eval" command."""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
models: model_registry.ModelRegistry,
|
| 37 |
+
env: ipython_env.IPythonEnv | None = None,
|
| 38 |
+
):
|
| 39 |
+
"""Constructor.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
models: ModelRegistry instance.
|
| 43 |
+
env: The IPythonEnv environment.
|
| 44 |
+
"""
|
| 45 |
+
super().__init__()
|
| 46 |
+
self._models = models
|
| 47 |
+
self._ipython_env = env
|
| 48 |
+
|
| 49 |
+
def execute(
|
| 50 |
+
self,
|
| 51 |
+
parsed_args: parsed_args_lib.ParsedArgs,
|
| 52 |
+
cell_content: str,
|
| 53 |
+
post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
|
| 54 |
+
) -> pandas.DataFrame:
|
| 55 |
+
# We expect CmdLineParser to have already read the inputs once to validate
|
| 56 |
+
# that the placeholders in the prompt are present in the inputs, so we can
|
| 57 |
+
# suppress the status messages here.
|
| 58 |
+
inputs = input_utils.join_inputs_sources(parsed_args, suppress_status_msgs=True)
|
| 59 |
+
|
| 60 |
+
llm_cmp_fn = command_utils.create_llm_eval_function(
|
| 61 |
+
models=self._models,
|
| 62 |
+
env=self._ipython_env,
|
| 63 |
+
parsed_args=parsed_args,
|
| 64 |
+
cell_content=cell_content,
|
| 65 |
+
post_processing_fns=post_processing_fns,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
results = llm_cmp_fn(inputs=inputs)
|
| 69 |
+
output_utils.write_to_outputs(results=results, parsed_args=parsed_args)
|
| 70 |
+
return results.as_pandas_dataframe()
|
| 71 |
+
|
| 72 |
+
def parse_post_processing_tokens(
|
| 73 |
+
self, tokens: Sequence[Sequence[str]]
|
| 74 |
+
) -> Sequence[post_process_utils.ParsedPostProcessExpr]:
|
| 75 |
+
return post_process_utils.resolve_post_processing_tokens(tokens)
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/flag_def.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Classes that define arguments for populating ArgumentParser.
|
| 16 |
+
|
| 17 |
+
The argparse module's ArgumentParser.add_argument() takes several parameters and
|
| 18 |
+
is quite customizable. However this can lead to bugs where arguments do not
|
| 19 |
+
behave as expected.
|
| 20 |
+
|
| 21 |
+
For better ease-of-use and better testability, define a set of classes for the
|
| 22 |
+
types of flags used by LLM Magics.
|
| 23 |
+
|
| 24 |
+
Sample usage:
|
| 25 |
+
|
| 26 |
+
str_flag = SingleValueFlagDef(name="title", required=True)
|
| 27 |
+
enum_flag = EnumFlagDef(name="colors", required=True, enum_type=ColorsEnum)
|
| 28 |
+
|
| 29 |
+
str_flag.add_argument_to_parser(my_parser)
|
| 30 |
+
enum_flag.add_argument_to_parser(my_parser)
|
| 31 |
+
"""
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
import abc
|
| 35 |
+
import argparse
|
| 36 |
+
import dataclasses
|
| 37 |
+
import enum
|
| 38 |
+
from typing import Any, Callable, Sequence, Tuple, Union
|
| 39 |
+
|
| 40 |
+
from google.generativeai.notebook.lib import llmfn_inputs_source
|
| 41 |
+
from google.generativeai.notebook.lib import llmfn_outputs
|
| 42 |
+
|
| 43 |
+
# These are the intermediate types that argparse.ArgumentParser.parse_args()
|
| 44 |
+
# will pass command line arguments into.
|
| 45 |
+
_PARSETYPES = Union[str, int, float]
|
| 46 |
+
# These are the final result types that the intermediate parsed values will be
|
| 47 |
+
# converted into. It is a superset of _PARSETYPES because we support converting
|
| 48 |
+
# the parsed type into a more precise type, e.g. from str to Enum.
|
| 49 |
+
_DESTTYPES = Union[
|
| 50 |
+
_PARSETYPES,
|
| 51 |
+
enum.Enum,
|
| 52 |
+
Tuple[str, Callable[[str, str], Any]],
|
| 53 |
+
Sequence[str], # For --compare_fn
|
| 54 |
+
llmfn_inputs_source.LLMFnInputsSource, # For --ground_truth
|
| 55 |
+
llmfn_outputs.LLMFnOutputsSink, # For --inputs # For --outputs
|
| 56 |
+
]
|
| 57 |
+
|
| 58 |
+
# The signature of a function that converts a command line argument from the
|
| 59 |
+
# intermediate parsed type to the result type.
|
| 60 |
+
_PARSEFN = Callable[[_PARSETYPES], _DESTTYPES]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _get_type_name(x: type[Any]) -> str:
|
| 64 |
+
try:
|
| 65 |
+
return x.__name__
|
| 66 |
+
except AttributeError:
|
| 67 |
+
return str(x)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _validate_flag_name(name: str) -> str:
|
| 71 |
+
"""Validation for long and short names for flags."""
|
| 72 |
+
if not name:
|
| 73 |
+
raise ValueError("Cannot be empty")
|
| 74 |
+
if name[0] == "-":
|
| 75 |
+
raise ValueError("Cannot start with dash")
|
| 76 |
+
return name
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dataclasses.dataclass(frozen=True)
|
| 80 |
+
class FlagDef(abc.ABC):
|
| 81 |
+
"""Abstract base class for flag definitions.
|
| 82 |
+
|
| 83 |
+
Attributes:
|
| 84 |
+
name: Long name, e.g. "colors" will define the flag "--colors".
|
| 85 |
+
required: Whether the flag must be provided on the command line.
|
| 86 |
+
short_name: Optional short name.
|
| 87 |
+
parse_type: The type that ArgumentParser should parse the command line
|
| 88 |
+
argument to.
|
| 89 |
+
dest_type: The type that the parsed value is converted to. This is used when
|
| 90 |
+
we want ArgumentParser to parse as one type, then convert to a different
|
| 91 |
+
type. E.g. for enums we parse as "str" then convert to the desired enum
|
| 92 |
+
type in order to provide cleaner help messages.
|
| 93 |
+
parse_to_dest_type_fn: If provided, this function will be used to convert
|
| 94 |
+
the value from `parse_type` to `dest_type`. This can be used for
|
| 95 |
+
validation as well.
|
| 96 |
+
choices: If provided, limit the set of acceptable values to these choices.
|
| 97 |
+
help_msg: If provided, adds help message when -h is used in the command
|
| 98 |
+
line.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
name: str
|
| 102 |
+
required: bool = False
|
| 103 |
+
|
| 104 |
+
short_name: str | None = None
|
| 105 |
+
|
| 106 |
+
parse_type: type[_PARSETYPES] = str
|
| 107 |
+
dest_type: type[_DESTTYPES] | None = None
|
| 108 |
+
parse_to_dest_type_fn: _PARSEFN | None = None
|
| 109 |
+
|
| 110 |
+
choices: list[_PARSETYPES] | None = None
|
| 111 |
+
help_msg: str | None = None
|
| 112 |
+
|
| 113 |
+
@abc.abstractmethod
|
| 114 |
+
def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
|
| 115 |
+
"""Adds this flag as an argument to `parser`.
|
| 116 |
+
|
| 117 |
+
Child classes should implement this as a call to parser.add_argument()
|
| 118 |
+
with the appropriate parameters.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
parser: The parser to which this argument will be added.
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
@abc.abstractmethod
|
| 125 |
+
def _do_additional_validation(self) -> None:
|
| 126 |
+
"""For child classes to do additional validation."""
|
| 127 |
+
|
| 128 |
+
def _get_dest_type(self) -> type[_DESTTYPES]:
|
| 129 |
+
"""Returns the final converted type."""
|
| 130 |
+
return self.parse_type if self.dest_type is None else self.dest_type
|
| 131 |
+
|
| 132 |
+
def _get_parse_to_dest_type_fn(
|
| 133 |
+
self,
|
| 134 |
+
) -> _PARSEFN:
|
| 135 |
+
"""Returns a function to convert from parse_type to dest_type."""
|
| 136 |
+
if self.parse_to_dest_type_fn is not None:
|
| 137 |
+
return self.parse_to_dest_type_fn
|
| 138 |
+
|
| 139 |
+
dest_type = self._get_dest_type()
|
| 140 |
+
if dest_type == self.parse_type:
|
| 141 |
+
return lambda x: x
|
| 142 |
+
else:
|
| 143 |
+
return dest_type
|
| 144 |
+
|
| 145 |
+
def __post_init__(self):
|
| 146 |
+
_validate_flag_name(self.name)
|
| 147 |
+
if self.short_name is not None:
|
| 148 |
+
_validate_flag_name(self.short_name)
|
| 149 |
+
|
| 150 |
+
self._do_additional_validation()
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _has_non_default_value(
|
| 154 |
+
namespace: argparse.Namespace,
|
| 155 |
+
dest: str,
|
| 156 |
+
has_default: bool = False,
|
| 157 |
+
default_value: Any = None,
|
| 158 |
+
) -> bool:
|
| 159 |
+
"""Returns true if `namespace.dest` is set to a non-default value.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
namespace: The Namespace that is populated by ArgumentParser.
|
| 163 |
+
dest: The attribute in the Namespace to be populated.
|
| 164 |
+
has_default: "None" is a valid default value so we use an additional
|
| 165 |
+
`has_default` boolean to indicate that `default_value` is present.
|
| 166 |
+
default_value: The default value to use when `has_default` is True.
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
Whether namespace.dest is set to something other than the default value.
|
| 170 |
+
"""
|
| 171 |
+
if not hasattr(namespace, dest):
|
| 172 |
+
return False
|
| 173 |
+
|
| 174 |
+
if not has_default:
|
| 175 |
+
# No default value provided so `namespace.dest` cannot possibly be equal to
|
| 176 |
+
# the default value.
|
| 177 |
+
return True
|
| 178 |
+
|
| 179 |
+
return getattr(namespace, dest) != default_value
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class _SingleValueStoreAction(argparse.Action):
|
| 183 |
+
"""Custom Action for storing a value in an argparse.Namespace.
|
| 184 |
+
|
| 185 |
+
This action checks that the flag is specified at-most once.
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
def __init__(
|
| 189 |
+
self,
|
| 190 |
+
option_strings,
|
| 191 |
+
dest,
|
| 192 |
+
dest_type: type[Any],
|
| 193 |
+
parse_to_dest_type_fn: _PARSEFN,
|
| 194 |
+
**kwargs,
|
| 195 |
+
):
|
| 196 |
+
super().__init__(option_strings, dest, **kwargs)
|
| 197 |
+
self._dest_type = dest_type
|
| 198 |
+
self._parse_to_dest_type_fn = parse_to_dest_type_fn
|
| 199 |
+
|
| 200 |
+
def __call__(
|
| 201 |
+
self,
|
| 202 |
+
parser: argparse.ArgumentParser,
|
| 203 |
+
namespace: argparse.Namespace,
|
| 204 |
+
values: str | Sequence[Any] | None,
|
| 205 |
+
option_string: str | None = None,
|
| 206 |
+
):
|
| 207 |
+
# Because `nargs` is set to 1, `values` must be a Sequence, rather
|
| 208 |
+
# than a string.
|
| 209 |
+
assert not isinstance(values, str) and not isinstance(values, bytes)
|
| 210 |
+
|
| 211 |
+
if _has_non_default_value(
|
| 212 |
+
namespace,
|
| 213 |
+
self.dest,
|
| 214 |
+
has_default=hasattr(self, "default"),
|
| 215 |
+
default_value=getattr(self, "default"),
|
| 216 |
+
):
|
| 217 |
+
raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string))
|
| 218 |
+
|
| 219 |
+
try:
|
| 220 |
+
converted_value = self._parse_to_dest_type_fn(values[0])
|
| 221 |
+
except Exception as e:
|
| 222 |
+
raise argparse.ArgumentError(
|
| 223 |
+
self,
|
| 224 |
+
'Error with value "{}", got {}: {}'.format(values[0], _get_type_name(type(e)), e),
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if not isinstance(converted_value, self._dest_type):
|
| 228 |
+
raise RuntimeError(
|
| 229 |
+
"Converted to wrong type, expected {} got {}".format(
|
| 230 |
+
_get_type_name(self._dest_type),
|
| 231 |
+
_get_type_name(type(converted_value)),
|
| 232 |
+
)
|
| 233 |
+
)
|
| 234 |
+
setattr(namespace, self.dest, converted_value)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class _MultiValuesAppendAction(argparse.Action):
|
| 238 |
+
"""Custom Action for appending values in an argparse.Namespace.
|
| 239 |
+
|
| 240 |
+
This action checks that the flag is specified at-most once.
|
| 241 |
+
"""
|
| 242 |
+
|
| 243 |
+
def __init__(
|
| 244 |
+
self,
|
| 245 |
+
option_strings,
|
| 246 |
+
dest,
|
| 247 |
+
dest_type: type[Any],
|
| 248 |
+
parse_to_dest_type_fn: _PARSEFN,
|
| 249 |
+
**kwargs,
|
| 250 |
+
):
|
| 251 |
+
super().__init__(option_strings, dest, **kwargs)
|
| 252 |
+
self._dest_type = dest_type
|
| 253 |
+
self._parse_to_dest_type_fn = parse_to_dest_type_fn
|
| 254 |
+
|
| 255 |
+
def __call__(
|
| 256 |
+
self,
|
| 257 |
+
parser: argparse.ArgumentParser,
|
| 258 |
+
namespace: argparse.Namespace,
|
| 259 |
+
values: str | Sequence[Any] | None,
|
| 260 |
+
option_string: str | None = None,
|
| 261 |
+
):
|
| 262 |
+
# Because `nargs` is set to "+", `values` must be a Sequence, rather
|
| 263 |
+
# than a string.
|
| 264 |
+
assert not isinstance(values, str) and not isinstance(values, bytes)
|
| 265 |
+
|
| 266 |
+
curr_value = getattr(namespace, self.dest)
|
| 267 |
+
if curr_value:
|
| 268 |
+
raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string))
|
| 269 |
+
|
| 270 |
+
for value in values:
|
| 271 |
+
try:
|
| 272 |
+
converted_value = self._parse_to_dest_type_fn(value)
|
| 273 |
+
except Exception as e:
|
| 274 |
+
raise argparse.ArgumentError(
|
| 275 |
+
self,
|
| 276 |
+
'Error with value "{}", got {}: {}'.format(
|
| 277 |
+
values[0], _get_type_name(type(e)), e
|
| 278 |
+
),
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
if not isinstance(converted_value, self._dest_type):
|
| 282 |
+
raise RuntimeError(
|
| 283 |
+
"Converted to wrong type, expected {} got {}".format(
|
| 284 |
+
self._dest_type, type(converted_value)
|
| 285 |
+
)
|
| 286 |
+
)
|
| 287 |
+
if converted_value in curr_value:
|
| 288 |
+
raise argparse.ArgumentError(self, 'Duplicate values "{}"'.format(value))
|
| 289 |
+
|
| 290 |
+
curr_value.append(converted_value)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class _BooleanValueStoreAction(argparse.Action):
|
| 294 |
+
"""Custom Action for setting a boolean value in argparse.Namespace.
|
| 295 |
+
|
| 296 |
+
The boolean flag expects the default to be False and will set the value to
|
| 297 |
+
True.
|
| 298 |
+
This action checks that the flag is specified at-most once.
|
| 299 |
+
"""
|
| 300 |
+
|
| 301 |
+
def __init__(
|
| 302 |
+
self,
|
| 303 |
+
option_strings,
|
| 304 |
+
dest,
|
| 305 |
+
**kwargs,
|
| 306 |
+
):
|
| 307 |
+
super().__init__(option_strings, dest, **kwargs)
|
| 308 |
+
|
| 309 |
+
def __call__(
|
| 310 |
+
self,
|
| 311 |
+
parser: argparse.ArgumentParser,
|
| 312 |
+
namespace: argparse.Namespace,
|
| 313 |
+
values: str | Sequence[Any] | None,
|
| 314 |
+
option_string: str | None = None,
|
| 315 |
+
):
|
| 316 |
+
if _has_non_default_value(
|
| 317 |
+
namespace,
|
| 318 |
+
self.dest,
|
| 319 |
+
has_default=True,
|
| 320 |
+
default_value=False,
|
| 321 |
+
):
|
| 322 |
+
raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string))
|
| 323 |
+
|
| 324 |
+
setattr(namespace, self.dest, True)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
@dataclasses.dataclass(frozen=True)
|
| 328 |
+
class SingleValueFlagDef(FlagDef):
|
| 329 |
+
"""Definition for a flag that takes a single value.
|
| 330 |
+
|
| 331 |
+
Sample usage:
|
| 332 |
+
# This defines a flag that can be specified on the command line as:
|
| 333 |
+
# --count=10
|
| 334 |
+
flag = SingleValueFlagDef(name="count", parse_type=int, required=True)
|
| 335 |
+
flag.add_argument_to_parser(argument_parser)
|
| 336 |
+
|
| 337 |
+
Attributes:
|
| 338 |
+
default_value: Default value for optional flags.
|
| 339 |
+
"""
|
| 340 |
+
|
| 341 |
+
class _DefaultValue(enum.Enum):
|
| 342 |
+
"""Special value to represent "no value provided".
|
| 343 |
+
|
| 344 |
+
"None" can be used as a default value, so in order to differentiate between
|
| 345 |
+
"None" and "no value provided", create a special value for "no value
|
| 346 |
+
provided".
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
NOT_SET = None
|
| 350 |
+
|
| 351 |
+
default_value: _DESTTYPES | _DefaultValue | None = _DefaultValue.NOT_SET
|
| 352 |
+
|
| 353 |
+
def _has_default_value(self) -> bool:
|
| 354 |
+
"""Returns whether `default_value` has been provided."""
|
| 355 |
+
return self.default_value != SingleValueFlagDef._DefaultValue.NOT_SET
|
| 356 |
+
|
| 357 |
+
def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
|
| 358 |
+
args = ["--" + self.name]
|
| 359 |
+
if self.short_name is not None:
|
| 360 |
+
args += ["-" + self.short_name]
|
| 361 |
+
|
| 362 |
+
kwargs = {}
|
| 363 |
+
if self._has_default_value():
|
| 364 |
+
kwargs["default"] = self.default_value
|
| 365 |
+
if self.choices is not None:
|
| 366 |
+
kwargs["choices"] = self.choices
|
| 367 |
+
if self.help_msg is not None:
|
| 368 |
+
kwargs["help"] = self.help_msg
|
| 369 |
+
|
| 370 |
+
parser.add_argument(
|
| 371 |
+
*args,
|
| 372 |
+
action=_SingleValueStoreAction,
|
| 373 |
+
type=self.parse_type,
|
| 374 |
+
dest_type=self._get_dest_type(),
|
| 375 |
+
parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(),
|
| 376 |
+
required=self.required,
|
| 377 |
+
nargs=1,
|
| 378 |
+
**kwargs,
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
def _do_additional_validation(self) -> None:
|
| 382 |
+
if self.required:
|
| 383 |
+
if self._has_default_value():
|
| 384 |
+
raise ValueError("Required flags cannot have default value")
|
| 385 |
+
else:
|
| 386 |
+
if not self._has_default_value():
|
| 387 |
+
raise ValueError("Optional flags must have a default value")
|
| 388 |
+
|
| 389 |
+
if self._has_default_value() and self.default_value is not None:
|
| 390 |
+
if not isinstance(self.default_value, self._get_dest_type()):
|
| 391 |
+
raise ValueError("Default value must be of the same type as the destination type")
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class EnumFlagDef(SingleValueFlagDef):
|
| 395 |
+
"""Definition for a flag that takes a value from an Enum.
|
| 396 |
+
|
| 397 |
+
Sample usage:
|
| 398 |
+
# This defines a flag that can be specified on the command line as:
|
| 399 |
+
# --color=red
|
| 400 |
+
flag = SingleValueFlagDef(name="color", enum_type=ColorsEnum,
|
| 401 |
+
required=True)
|
| 402 |
+
flag.add_argument_to_parser(argument_parser)
|
| 403 |
+
"""
|
| 404 |
+
|
| 405 |
+
def __init__(self, *args, enum_type: type[enum.Enum], **kwargs):
|
| 406 |
+
if not issubclass(enum_type, enum.Enum):
|
| 407 |
+
raise TypeError('"enum_type" must be of type Enum')
|
| 408 |
+
|
| 409 |
+
# These properties are set by "enum_type" so don"t let the caller set them.
|
| 410 |
+
if "parse_type" in kwargs:
|
| 411 |
+
raise ValueError('Cannot set "parse_type" for EnumFlagDef; set "enum_type" instead')
|
| 412 |
+
kwargs["parse_type"] = str
|
| 413 |
+
|
| 414 |
+
if "dest_type" in kwargs:
|
| 415 |
+
raise ValueError('Cannot set "dest_type" for EnumFlagDef; set "enum_type" instead')
|
| 416 |
+
kwargs["dest_type"] = enum_type
|
| 417 |
+
|
| 418 |
+
if "choices" in kwargs:
|
| 419 |
+
# Verify that entries in `choices` are valid enum values.
|
| 420 |
+
for x in kwargs["choices"]:
|
| 421 |
+
try:
|
| 422 |
+
enum_type(x)
|
| 423 |
+
except ValueError:
|
| 424 |
+
raise ValueError('Invalid value in "choices": "{}"'.format(x)) from None
|
| 425 |
+
else:
|
| 426 |
+
kwargs["choices"] = [x.value for x in enum_type]
|
| 427 |
+
|
| 428 |
+
super().__init__(*args, **kwargs)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class MultiValuesFlagDef(FlagDef):
|
| 432 |
+
"""Definition for a flag that takes multiple values.
|
| 433 |
+
|
| 434 |
+
Sample usage:
|
| 435 |
+
# This defines a flag that can be specified on the command line as:
|
| 436 |
+
# --colors=red green blue
|
| 437 |
+
flag = MultiValuesFlagDef(name="colors", parse_type=str, required=True)
|
| 438 |
+
flag.add_argument_to_parser(argument_parser)
|
| 439 |
+
"""
|
| 440 |
+
|
| 441 |
+
def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
|
| 442 |
+
args = ["--" + self.name]
|
| 443 |
+
if self.short_name is not None:
|
| 444 |
+
args += ["-" + self.short_name]
|
| 445 |
+
|
| 446 |
+
kwargs = {}
|
| 447 |
+
if self.choices is not None:
|
| 448 |
+
kwargs["choices"] = self.choices
|
| 449 |
+
if self.help_msg is not None:
|
| 450 |
+
kwargs["help"] = self.help_msg
|
| 451 |
+
|
| 452 |
+
parser.add_argument(
|
| 453 |
+
*args,
|
| 454 |
+
action=_MultiValuesAppendAction,
|
| 455 |
+
type=self.parse_type,
|
| 456 |
+
dest_type=self._get_dest_type(),
|
| 457 |
+
parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(),
|
| 458 |
+
required=self.required,
|
| 459 |
+
default=[],
|
| 460 |
+
nargs="+",
|
| 461 |
+
**kwargs,
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
def _do_additional_validation(self) -> None:
|
| 465 |
+
# No additional validation needed.
|
| 466 |
+
pass
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
@dataclasses.dataclass(frozen=True)
|
| 470 |
+
class BooleanFlagDef(FlagDef):
|
| 471 |
+
"""Definition for a Boolean flag.
|
| 472 |
+
|
| 473 |
+
A boolean flag is always optional with a default value of False. The flag does
|
| 474 |
+
not take any values. Specifying the flag on the commandline will set it to
|
| 475 |
+
True.
|
| 476 |
+
"""
|
| 477 |
+
|
| 478 |
+
def _do_additional_validation(self) -> None:
|
| 479 |
+
if self.dest_type is not None:
|
| 480 |
+
raise ValueError("dest_type cannot be set for BooleanFlagDef")
|
| 481 |
+
if self.parse_to_dest_type_fn is not None:
|
| 482 |
+
raise ValueError("parse_to_dest_type_fn cannot be set for BooleanFlagDef")
|
| 483 |
+
if self.choices is not None:
|
| 484 |
+
raise ValueError("choices cannot be set for BooleanFlagDef")
|
| 485 |
+
|
| 486 |
+
def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
|
| 487 |
+
args = ["--" + self.name]
|
| 488 |
+
if self.short_name is not None:
|
| 489 |
+
args += ["-" + self.short_name]
|
| 490 |
+
|
| 491 |
+
kwargs = {}
|
| 492 |
+
if self.help_msg is not None:
|
| 493 |
+
kwargs["help"] = self.help_msg
|
| 494 |
+
|
| 495 |
+
parser.add_argument(
|
| 496 |
+
*args,
|
| 497 |
+
action=_BooleanValueStoreAction,
|
| 498 |
+
type=bool,
|
| 499 |
+
required=False,
|
| 500 |
+
default=False,
|
| 501 |
+
nargs=0,
|
| 502 |
+
**kwargs,
|
| 503 |
+
)
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/gspread_client.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Module that holds a global gspread.client.Client."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import abc
|
| 19 |
+
import datetime
|
| 20 |
+
from typing import Any, Callable, Mapping, Sequence
|
| 21 |
+
from google.auth import credentials
|
| 22 |
+
from google.generativeai.notebook import html_utils
|
| 23 |
+
from google.generativeai.notebook import ipython_env
|
| 24 |
+
from google.generativeai.notebook import sheets_id
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# The code may be running in an environment where the gspread library has not
|
| 28 |
+
# been installed.
|
| 29 |
+
_gspread_import_error: Exception | None = None
|
| 30 |
+
try:
|
| 31 |
+
# pylint: disable-next=g-import-not-at-top
|
| 32 |
+
import gspread
|
| 33 |
+
except ImportError as e:
|
| 34 |
+
_gspread_import_error = e
|
| 35 |
+
gspread = None
|
| 36 |
+
|
| 37 |
+
# Base class of exceptions that gspread.open(), open_by_url() and open_by_key()
|
| 38 |
+
# may throw.
|
| 39 |
+
GSpreadException = Exception if gspread is None else gspread.exceptions.GSpreadException # type: ignore
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SpreadsheetNotFoundError(RuntimeError):
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _get_import_error() -> Exception:
|
| 47 |
+
return RuntimeError('"gspread" module not imported, got: {}'.format(_gspread_import_error))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class GSpreadClient(abc.ABC):
|
| 51 |
+
"""Wrapper around gspread.client.Client.
|
| 52 |
+
|
| 53 |
+
This adds a layer of indirection for us to inject mocks for testing.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
@abc.abstractmethod
|
| 57 |
+
def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
|
| 58 |
+
"""Validates that `name` is the name of a Google Sheets document.
|
| 59 |
+
|
| 60 |
+
Raises an exception if false.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
sid: The identifier for the document.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
@abc.abstractmethod
|
| 67 |
+
def get_all_records(
|
| 68 |
+
self,
|
| 69 |
+
sid: sheets_id.SheetsIdentifier,
|
| 70 |
+
worksheet_id: int,
|
| 71 |
+
) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
|
| 72 |
+
"""Returns all records for a Google Sheets worksheet."""
|
| 73 |
+
|
| 74 |
+
@abc.abstractmethod
|
| 75 |
+
def write_records(
|
| 76 |
+
self,
|
| 77 |
+
sid: sheets_id.SheetsIdentifier,
|
| 78 |
+
rows: Sequence[Sequence[Any]],
|
| 79 |
+
) -> None:
|
| 80 |
+
"""Writes results to a new worksheet to the Google Sheets document."""
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class GSpreadClientImpl(GSpreadClient):
|
| 84 |
+
"""Concrete implementation of GSpreadClient."""
|
| 85 |
+
|
| 86 |
+
def __init__(self, client: Any, env: ipython_env.IPythonEnv | None):
|
| 87 |
+
"""Constructor.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
client: Instance of gspread.client.Client.
|
| 91 |
+
env: Optional instance of IPythonEnv. This is used to display messages
|
| 92 |
+
such as the URL of the output Worksheet.
|
| 93 |
+
"""
|
| 94 |
+
self._client = client
|
| 95 |
+
self._ipython_env = env
|
| 96 |
+
|
| 97 |
+
def _open(self, sid: sheets_id.SheetsIdentifier):
|
| 98 |
+
"""Opens a Sheets document from `sid`.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
sid: The identifier for the Sheets document.
|
| 102 |
+
|
| 103 |
+
Raises:
|
| 104 |
+
SpreadsheetNotFoundError: If the Sheets document cannot be found or
|
| 105 |
+
cannot be opened.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
A gspread.Worksheet instance representing the worksheet referred to by
|
| 109 |
+
`sid`.
|
| 110 |
+
"""
|
| 111 |
+
try:
|
| 112 |
+
if sid.name():
|
| 113 |
+
return self._client.open(sid.name())
|
| 114 |
+
if sid.key():
|
| 115 |
+
return self._client.open_by_key(str(sid.key()))
|
| 116 |
+
if sid.url():
|
| 117 |
+
return self._client.open_by_url(str(sid.url()))
|
| 118 |
+
except GSpreadException as exc:
|
| 119 |
+
raise SpreadsheetNotFoundError("Unable to find Sheets with {}".format(sid)) from exc
|
| 120 |
+
raise SpreadsheetNotFoundError("Invalid sheets_id.SheetsIdentifier")
|
| 121 |
+
|
| 122 |
+
def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
|
| 123 |
+
self._open(sid)
|
| 124 |
+
|
| 125 |
+
def get_all_records(
|
| 126 |
+
self,
|
| 127 |
+
sid: sheets_id.SheetsIdentifier,
|
| 128 |
+
worksheet_id: int,
|
| 129 |
+
) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
|
| 130 |
+
sheet = self._open(sid)
|
| 131 |
+
worksheet = sheet.get_worksheet(worksheet_id)
|
| 132 |
+
|
| 133 |
+
if self._ipython_env is not None:
|
| 134 |
+
env = self._ipython_env
|
| 135 |
+
|
| 136 |
+
def _display_fn():
|
| 137 |
+
env.display_html(
|
| 138 |
+
"Reading inputs from worksheet {}".format(
|
| 139 |
+
html_utils.get_anchor_tag(
|
| 140 |
+
url=sheets_id.SheetsURL(worksheet.url),
|
| 141 |
+
text="{} in {}".format(worksheet.title, sheet.title),
|
| 142 |
+
)
|
| 143 |
+
)
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
else:
|
| 147 |
+
|
| 148 |
+
def _display_fn():
|
| 149 |
+
print("Reading inputs from worksheet {} in {}".format(worksheet.title, sheet.title))
|
| 150 |
+
|
| 151 |
+
return worksheet.get_all_records(), _display_fn
|
| 152 |
+
|
| 153 |
+
def write_records(
|
| 154 |
+
self,
|
| 155 |
+
sid: sheets_id.SheetsIdentifier,
|
| 156 |
+
rows: Sequence[Sequence[Any]],
|
| 157 |
+
) -> None:
|
| 158 |
+
sheet = self._open(sid)
|
| 159 |
+
|
| 160 |
+
# Create a new Worksheet.
|
| 161 |
+
# `title` has to be carefully constructed: some characters like colon ":"
|
| 162 |
+
# will not work with gspread in Worksheet.append_rows().
|
| 163 |
+
current_datetime = datetime.datetime.now()
|
| 164 |
+
title = f"Results {current_datetime:%Y_%m_%d} ({current_datetime:%s})"
|
| 165 |
+
|
| 166 |
+
# append_rows() will resize the worksheet as needed, so `rows` and `cols`
|
| 167 |
+
# can be set to 1 to create a worksheet with only a single cell.
|
| 168 |
+
worksheet = sheet.add_worksheet(title=title, rows=1, cols=1)
|
| 169 |
+
worksheet.append_rows(values=rows)
|
| 170 |
+
|
| 171 |
+
if self._ipython_env is not None:
|
| 172 |
+
self._ipython_env.display_html(
|
| 173 |
+
"Results written to new worksheet {}".format(
|
| 174 |
+
html_utils.get_anchor_tag(
|
| 175 |
+
url=sheets_id.SheetsURL(worksheet.url),
|
| 176 |
+
text="{} in {}".format(worksheet.title, sheet.title),
|
| 177 |
+
)
|
| 178 |
+
)
|
| 179 |
+
)
|
| 180 |
+
else:
|
| 181 |
+
print("Results written to new worksheet {} in {}".format(worksheet.title, sheet.title))
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class NullGSpreadClient(GSpreadClient):
|
| 185 |
+
"""Null-object implementation of GSpreadClient.
|
| 186 |
+
|
| 187 |
+
This class raises an error if any of its methods are called. It is used when
|
| 188 |
+
the gspread library is not available.
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
|
| 192 |
+
raise _get_import_error()
|
| 193 |
+
|
| 194 |
+
def get_all_records(
|
| 195 |
+
self,
|
| 196 |
+
sid: sheets_id.SheetsIdentifier,
|
| 197 |
+
worksheet_id: int,
|
| 198 |
+
) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
|
| 199 |
+
raise _get_import_error()
|
| 200 |
+
|
| 201 |
+
def write_records(
|
| 202 |
+
self,
|
| 203 |
+
sid: sheets_id.SheetsIdentifier,
|
| 204 |
+
rows: Sequence[Sequence[Any]],
|
| 205 |
+
) -> None:
|
| 206 |
+
raise _get_import_error()
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# Global instance of gspread client.
|
| 210 |
+
_gspread_client: GSpreadClient | None = None
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def authorize(creds: credentials.Credentials, env: ipython_env.IPythonEnv | None) -> None:
|
| 214 |
+
"""Sets up credential for gspreads."""
|
| 215 |
+
global _gspread_client
|
| 216 |
+
if gspread is not None:
|
| 217 |
+
client = gspread.authorize(creds) # type: ignore
|
| 218 |
+
_gspread_client = GSpreadClientImpl(client=client, env=env)
|
| 219 |
+
else:
|
| 220 |
+
_gspread_client = NullGSpreadClient()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def get_client() -> GSpreadClient:
|
| 224 |
+
if not _gspread_client:
|
| 225 |
+
raise RuntimeError("Must call authorize() first")
|
| 226 |
+
return _gspread_client
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def testonly_set_client(client: GSpreadClient) -> None:
|
| 230 |
+
"""Overrides the global client for testing."""
|
| 231 |
+
global _gspread_client
|
| 232 |
+
_gspread_client = client
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/html_utils.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Utilities for generating HTML."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from xml.etree import ElementTree
|
| 19 |
+
from google.generativeai.notebook import sheets_id
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_anchor_tag(url: sheets_id.SheetsURL, text: str) -> str:
|
| 23 |
+
"""Returns a HTML string representing an anchor tag.
|
| 24 |
+
|
| 25 |
+
This class uses the xml.etree library to handle HTML escaping.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
url: The Sheets URL to link to.
|
| 29 |
+
text: The text body of the link.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
A string representing a HTML fragment.
|
| 33 |
+
"""
|
| 34 |
+
tag = ElementTree.Element(
|
| 35 |
+
"a",
|
| 36 |
+
attrib={
|
| 37 |
+
# Open in a new window/tab
|
| 38 |
+
"target": "_blank",
|
| 39 |
+
# See:
|
| 40 |
+
# https://developer.chrome.com/en/docs/lighthouse/best-practices/external-anchors-use-rel-noopener/
|
| 41 |
+
"rel": "noopener",
|
| 42 |
+
"href": str(url),
|
| 43 |
+
},
|
| 44 |
+
)
|
| 45 |
+
tag.text = text if text else "link"
|
| 46 |
+
return ElementTree.tostring(tag, encoding="unicode", method="html")
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/input_utils.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Utilities for handling input variables."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from typing import Callable, Mapping
|
| 19 |
+
|
| 20 |
+
from google.generativeai.notebook import parsed_args_lib
|
| 21 |
+
from google.generativeai.notebook import py_utils
|
| 22 |
+
from google.generativeai.notebook.lib import llmfn_input_utils
|
| 23 |
+
from google.generativeai.notebook.lib import llmfn_inputs_source
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class _NormalizedInputsSource(llmfn_inputs_source.LLMFnInputsSource):
|
| 27 |
+
"""Wrapper around NormalizedInputsList.
|
| 28 |
+
|
| 29 |
+
By design LLMFunction does not take NormalizedInputsList as input because
|
| 30 |
+
NormalizedInputsList is an internal representation so we want to minimize
|
| 31 |
+
exposure to the caller.
|
| 32 |
+
|
| 33 |
+
When we have inputs already in normalized format (e.g. from
|
| 34 |
+
join_prompt_inputs()) we can wrap it as an LLMFnInputsSource to pass as an
|
| 35 |
+
input to LLMFunction.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, normalized_inputs: llmfn_inputs_source.NormalizedInputsList):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self._normalized_inputs = normalized_inputs
|
| 41 |
+
|
| 42 |
+
def _to_normalized_inputs_impl(
|
| 43 |
+
self,
|
| 44 |
+
) -> tuple[llmfn_inputs_source.NormalizedInputsList, Callable[[], None]]:
|
| 45 |
+
return self._normalized_inputs, lambda: None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_inputs_source_from_py_var(
|
| 49 |
+
var_name: str,
|
| 50 |
+
) -> llmfn_inputs_source.LLMFnInputsSource:
|
| 51 |
+
data = py_utils.get_py_var(var_name)
|
| 52 |
+
if isinstance(data, llmfn_inputs_source.LLMFnInputsSource):
|
| 53 |
+
# No conversion needed.
|
| 54 |
+
return data
|
| 55 |
+
normalized_inputs = llmfn_input_utils.to_normalized_inputs(data)
|
| 56 |
+
return _NormalizedInputsSource(normalized_inputs)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def join_inputs_sources(
|
| 60 |
+
parsed_args: parsed_args_lib.ParsedArgs,
|
| 61 |
+
suppress_status_msgs: bool = False,
|
| 62 |
+
) -> llmfn_inputs_source.LLMFnInputsSource:
|
| 63 |
+
"""Get a single combined input source from `parsed_args."""
|
| 64 |
+
combined_inputs: list[Mapping[str, str]] = []
|
| 65 |
+
for source in parsed_args.inputs:
|
| 66 |
+
combined_inputs.extend(
|
| 67 |
+
source.to_normalized_inputs(suppress_status_msgs=suppress_status_msgs)
|
| 68 |
+
)
|
| 69 |
+
for source in parsed_args.sheets_input_names:
|
| 70 |
+
combined_inputs.extend(
|
| 71 |
+
source.to_normalized_inputs(suppress_status_msgs=suppress_status_msgs)
|
| 72 |
+
)
|
| 73 |
+
return _NormalizedInputsSource(combined_inputs)
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/ipython_env_impl.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""IPythonEnvImpl."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from typing import Any
|
| 19 |
+
from google.generativeai.notebook import ipython_env
|
| 20 |
+
from IPython.core import display as ipython_display
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class IPythonEnvImpl(ipython_env.IPythonEnv):
|
| 24 |
+
"""Concrete implementation of IPythonEnv."""
|
| 25 |
+
|
| 26 |
+
def display(self, x: Any) -> None:
|
| 27 |
+
ipython_display.display(x)
|
| 28 |
+
|
| 29 |
+
def display_html(self, x: str) -> None:
|
| 30 |
+
ipython_display.display(ipython_display.HTML(x))
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/magics.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Colab Magics class.
|
| 16 |
+
|
| 17 |
+
Installs %%llm magics.
|
| 18 |
+
"""
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import abc
|
| 22 |
+
|
| 23 |
+
from google.auth import credentials
|
| 24 |
+
from google.generativeai import client as genai
|
| 25 |
+
from google.generativeai.notebook import gspread_client
|
| 26 |
+
from google.generativeai.notebook import ipython_env
|
| 27 |
+
from google.generativeai.notebook import ipython_env_impl
|
| 28 |
+
from google.generativeai.notebook import magics_engine
|
| 29 |
+
from google.generativeai.notebook import post_process_utils
|
| 30 |
+
from google.generativeai.notebook import sheets_utils
|
| 31 |
+
|
| 32 |
+
import IPython
|
| 33 |
+
from IPython.core import magic
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Set the UA to distinguish the magic from the client. Do this at import-time
|
| 37 |
+
# so that a user can still call `genai.configure()`, and both their settings
|
| 38 |
+
# and this are honored.
|
| 39 |
+
genai.USER_AGENT = "genai-py-magic"
|
| 40 |
+
|
| 41 |
+
SheetsInputs = sheets_utils.SheetsInputs
|
| 42 |
+
SheetsOutputs = sheets_utils.SheetsOutputs
|
| 43 |
+
|
| 44 |
+
# Decorator functions for post-processing.
|
| 45 |
+
post_process_add_fn = post_process_utils.post_process_add_fn
|
| 46 |
+
post_process_replace_fn = post_process_utils.post_process_replace_fn
|
| 47 |
+
|
| 48 |
+
# Globals.
|
| 49 |
+
_ipython_env: ipython_env.IPythonEnv | None = None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _get_ipython_env() -> ipython_env.IPythonEnv:
|
| 53 |
+
"""Lazily constructs and returns a global IPythonEnv instance."""
|
| 54 |
+
global _ipython_env
|
| 55 |
+
if _ipython_env is None:
|
| 56 |
+
_ipython_env = ipython_env_impl.IPythonEnvImpl()
|
| 57 |
+
return _ipython_env
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def authorize(creds: credentials.Credentials) -> None:
|
| 61 |
+
"""Sets up credentials.
|
| 62 |
+
|
| 63 |
+
This is used for interacting Google APIs, such as Google Sheets.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
creds: The credentials that will be used (e.g. to read from Google Sheets.)
|
| 67 |
+
"""
|
| 68 |
+
gspread_client.authorize(creds=creds, env=_get_ipython_env())
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class AbstractMagics(abc.ABC):
|
| 72 |
+
"""Defines interface to Magics class."""
|
| 73 |
+
|
| 74 |
+
@abc.abstractmethod
|
| 75 |
+
def llm(self, cell_line: str | None, cell_body: str | None):
|
| 76 |
+
"""Perform various LLM-related operations.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
cell_line: String to pass to the MagicsEngine.
|
| 80 |
+
cell_body: Contents of the cell body.
|
| 81 |
+
"""
|
| 82 |
+
raise NotImplementedError()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class MagicsImpl(AbstractMagics):
|
| 86 |
+
"""Actual class implementing the magics functionality.
|
| 87 |
+
|
| 88 |
+
We use a separate class to ensure a single, global instance
|
| 89 |
+
of the magics class.
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
def __init__(self):
|
| 93 |
+
self._engine = magics_engine.MagicsEngine(env=_get_ipython_env())
|
| 94 |
+
|
| 95 |
+
def llm(self, cell_line: str | None, cell_body: str | None):
|
| 96 |
+
"""Perform various LLM-related operations.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
cell_line: String to pass to the MagicsEngine.
|
| 100 |
+
cell_body: Contents of the cell body.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
Results from running MagicsEngine.
|
| 104 |
+
"""
|
| 105 |
+
cell_line = cell_line or ""
|
| 106 |
+
cell_body = cell_body or ""
|
| 107 |
+
return self._engine.execute_cell(cell_line, cell_body)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@magic.magics_class
|
| 111 |
+
class Magics(magic.Magics):
|
| 112 |
+
"""Class to register the magic with Colab.
|
| 113 |
+
|
| 114 |
+
Objects of this class delegate all calls to a single,
|
| 115 |
+
global instance.
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
# Global instance
|
| 119 |
+
_instance = None
|
| 120 |
+
|
| 121 |
+
@classmethod
|
| 122 |
+
def get_instance(cls) -> AbstractMagics:
|
| 123 |
+
"""Retrieve global instance of the Magics object."""
|
| 124 |
+
if cls._instance is None:
|
| 125 |
+
cls._instance = MagicsImpl()
|
| 126 |
+
return cls._instance
|
| 127 |
+
|
| 128 |
+
@magic.line_cell_magic
|
| 129 |
+
def llm(self, cell_line: str | None, cell_body: str | None):
|
| 130 |
+
"""Perform various LLM-related operations.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
cell_line: String to pass to the MagicsEngine.
|
| 134 |
+
cell_body: Contents of the cell body.
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
Results from running MagicsEngine.
|
| 138 |
+
"""
|
| 139 |
+
return Magics.get_instance().llm(cell_line=cell_line, cell_body=cell_body)
|
| 140 |
+
|
| 141 |
+
@magic.line_cell_magic
|
| 142 |
+
def palm(self, cell_line: str | None, cell_body: str | None):
|
| 143 |
+
return self.llm(cell_line, cell_body)
|
| 144 |
+
|
| 145 |
+
@magic.line_cell_magic
|
| 146 |
+
def gemini(self, cell_line: str | None, cell_body: str | None):
|
| 147 |
+
return self.llm(cell_line, cell_body)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
IPython.get_ipython().register_magics(Magics)
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/magics_engine.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""MagicsEngine class."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from typing import AbstractSet, Sequence
|
| 19 |
+
|
| 20 |
+
from google.generativeai.notebook import argument_parser
|
| 21 |
+
from google.generativeai.notebook import cmd_line_parser
|
| 22 |
+
from google.generativeai.notebook import command
|
| 23 |
+
from google.generativeai.notebook import compare_cmd
|
| 24 |
+
from google.generativeai.notebook import compile_cmd
|
| 25 |
+
from google.generativeai.notebook import eval_cmd
|
| 26 |
+
from google.generativeai.notebook import ipython_env
|
| 27 |
+
from google.generativeai.notebook import model_registry
|
| 28 |
+
from google.generativeai.notebook import parsed_args_lib
|
| 29 |
+
from google.generativeai.notebook import post_process_utils
|
| 30 |
+
from google.generativeai.notebook import run_cmd
|
| 31 |
+
from google.generativeai.notebook.lib import prompt_utils
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class MagicsEngine:
|
| 35 |
+
"""Implementation of functionality used by Magics.
|
| 36 |
+
|
| 37 |
+
This class provides the implementation for Magics, decoupled from the
|
| 38 |
+
details of integrating with Colab Magics such as registration.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
registry: model_registry.ModelRegistry | None = None,
|
| 44 |
+
env: ipython_env.IPythonEnv | None = None,
|
| 45 |
+
):
|
| 46 |
+
self._ipython_env = env
|
| 47 |
+
models = registry or model_registry.ModelRegistry()
|
| 48 |
+
self._cmd_handlers: dict[parsed_args_lib.CommandName, command.Command] = {
|
| 49 |
+
parsed_args_lib.CommandName.RUN_CMD: run_cmd.RunCommand(models=models, env=env),
|
| 50 |
+
parsed_args_lib.CommandName.COMPILE_CMD: compile_cmd.CompileCommand(
|
| 51 |
+
models=models, env=env
|
| 52 |
+
),
|
| 53 |
+
parsed_args_lib.CommandName.COMPARE_CMD: compare_cmd.CompareCommand(env=env),
|
| 54 |
+
parsed_args_lib.CommandName.EVAL_CMD: eval_cmd.EvalCommand(models=models, env=env),
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
def parse_line(
|
| 58 |
+
self,
|
| 59 |
+
line: str,
|
| 60 |
+
placeholders: AbstractSet[str],
|
| 61 |
+
) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]:
|
| 62 |
+
return cmd_line_parser.CmdLineParser().parse_line(line, placeholders)
|
| 63 |
+
|
| 64 |
+
def _get_handler(self, line: str, placeholders: AbstractSet[str]) -> tuple[
|
| 65 |
+
command.Command,
|
| 66 |
+
parsed_args_lib.ParsedArgs,
|
| 67 |
+
Sequence[post_process_utils.ParsedPostProcessExpr],
|
| 68 |
+
]:
|
| 69 |
+
"""Given the command line, parse and return all components.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
line: The LLM Magics command line.
|
| 73 |
+
placeholders: Placeholders from prompts in the cell contents.
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
A three-tuple containing:
|
| 77 |
+
- The command (e.g. "run")
|
| 78 |
+
- Parsed arguments for the command,
|
| 79 |
+
- Parsed post-processing expressions
|
| 80 |
+
"""
|
| 81 |
+
parsed_args, post_processing_tokens = self.parse_line(line, placeholders)
|
| 82 |
+
cmd_name = parsed_args.cmd
|
| 83 |
+
handler = self._cmd_handlers[cmd_name]
|
| 84 |
+
post_processing_fns = handler.parse_post_processing_tokens(post_processing_tokens)
|
| 85 |
+
return handler, parsed_args, post_processing_fns
|
| 86 |
+
|
| 87 |
+
def execute_cell(self, line: str, cell_content: str):
|
| 88 |
+
"""Executes the supplied magic line and cell payload."""
|
| 89 |
+
cell = _clean_cell(cell_content)
|
| 90 |
+
placeholders = prompt_utils.get_placeholders(cell)
|
| 91 |
+
|
| 92 |
+
try:
|
| 93 |
+
handler, parsed_args, post_processing_fns = self._get_handler(line, placeholders)
|
| 94 |
+
return handler.execute(parsed_args, cell, post_processing_fns)
|
| 95 |
+
except argument_parser.ParserNormalExit as e:
|
| 96 |
+
if self._ipython_env is not None:
|
| 97 |
+
e.set_ipython_env(self._ipython_env)
|
| 98 |
+
# ParserNormalExit implements the _ipython_display_ method so it can
|
| 99 |
+
# be returned as the output of this cell for display.
|
| 100 |
+
return e
|
| 101 |
+
except argument_parser.ParserError as e:
|
| 102 |
+
e.display(self._ipython_env)
|
| 103 |
+
# Raise an exception to indicate that execution for this cell has
|
| 104 |
+
# failed.
|
| 105 |
+
# The exception is re-raised as SystemExit because Colab automatically
|
| 106 |
+
# suppresses traceback for SystemExit but not other exceptions. Because
|
| 107 |
+
# ParserErrors are usually due to user error (e.g. a missing required
|
| 108 |
+
# flag or an invalid flag value), we want to hide the traceback to
|
| 109 |
+
# avoid detracting the user from the error message, and we want to
|
| 110 |
+
# reserve exceptions-with-traceback for actual bugs and unexpected
|
| 111 |
+
# errors.
|
| 112 |
+
error_msg = "Got parser error: {}".format(e.msgs()[-1]) if e.msgs() else ""
|
| 113 |
+
raise SystemExit(error_msg) from e
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _clean_cell(cell_content: str) -> str:
|
| 117 |
+
# Colab includes a trailing newline in cell_content. Remove only the last
|
| 118 |
+
# line break from cell contents (i.e. not rstrip), so that multi-line and
|
| 119 |
+
# intentional line breaks are preserved, but single-line prompts don't have
|
| 120 |
+
# a trailing line break.
|
| 121 |
+
cell = cell_content
|
| 122 |
+
if cell.endswith("\n"):
|
| 123 |
+
cell = cell[:-1]
|
| 124 |
+
return cell
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/model_registry.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Maintains set of LLM models that can be instantiated by name."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import enum
|
| 19 |
+
from typing import Callable
|
| 20 |
+
|
| 21 |
+
from google.generativeai.notebook import text_model
|
| 22 |
+
from google.generativeai.notebook.lib import model as model_lib
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ModelName(enum.Enum):
|
| 26 |
+
ECHO_MODEL = "echo"
|
| 27 |
+
TEXT_MODEL = "text"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ModelRegistry:
|
| 31 |
+
"""Registry that instantiates and caches models."""
|
| 32 |
+
|
| 33 |
+
DEFAULT_MODEL = ModelName.TEXT_MODEL
|
| 34 |
+
|
| 35 |
+
def __init__(self):
|
| 36 |
+
self._model_cache: dict[ModelName, model_lib.AbstractModel] = {}
|
| 37 |
+
self._model_constructors: dict[ModelName, Callable[[], model_lib.AbstractModel]] = {
|
| 38 |
+
ModelName.ECHO_MODEL: model_lib.EchoModel,
|
| 39 |
+
ModelName.TEXT_MODEL: text_model.TextModel,
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
def get_model(self, model_name: ModelName) -> model_lib.AbstractModel:
|
| 43 |
+
"""Given `model_name`, return the corresponding Model instance.
|
| 44 |
+
|
| 45 |
+
Model instances are cached and reused for the same `model_name`.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
model_name: The name of the model.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
The corresponding model instance for `model_name`.
|
| 52 |
+
"""
|
| 53 |
+
if model_name not in self._model_cache:
|
| 54 |
+
self._model_cache[model_name] = self._model_constructors[model_name]()
|
| 55 |
+
return self._model_cache[model_name]
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/parsed_args_lib.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Results from parsing the commandline.
|
| 16 |
+
|
| 17 |
+
This module separates the results from commandline parsing from the parser
|
| 18 |
+
itself so that classes that operate on the results (e.g. the subclasses of
|
| 19 |
+
Command) do not have to depend on the commandline parser as well.
|
| 20 |
+
"""
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import dataclasses
|
| 24 |
+
import enum
|
| 25 |
+
from typing import Any, Callable, Sequence
|
| 26 |
+
|
| 27 |
+
from google.generativeai.notebook import model_registry
|
| 28 |
+
from google.generativeai.notebook.lib import llm_function
|
| 29 |
+
from google.generativeai.notebook.lib import llmfn_inputs_source
|
| 30 |
+
from google.generativeai.notebook.lib import llmfn_outputs
|
| 31 |
+
from google.generativeai.notebook.lib import model as model_lib
|
| 32 |
+
|
| 33 |
+
# Post processing tokens are represented as a sequence of sequence of tokens,
|
| 34 |
+
# because the pipe operator could be used more than once.
|
| 35 |
+
PostProcessingTokens = Sequence[Sequence[str]]
|
| 36 |
+
|
| 37 |
+
# The type of function taken by the "compare_fn" flag.
|
| 38 |
+
# It takes the text_results of the left- and right-hand side functions as
|
| 39 |
+
# inputs and returns a comparison result.
|
| 40 |
+
TextResultCompareFn = Callable[[str, str], Any]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class CommandName(enum.Enum):
|
| 44 |
+
RUN_CMD = "run"
|
| 45 |
+
COMPILE_CMD = "compile"
|
| 46 |
+
COMPARE_CMD = "compare"
|
| 47 |
+
EVAL_CMD = "eval"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclasses.dataclass(frozen=True)
|
| 51 |
+
class ParsedArgs:
|
| 52 |
+
"""The results of parsing the command line."""
|
| 53 |
+
|
| 54 |
+
cmd: CommandName
|
| 55 |
+
|
| 56 |
+
# For run, compile and eval commands.
|
| 57 |
+
model_args: model_lib.ModelArguments
|
| 58 |
+
model_type: model_registry.ModelName | None = None
|
| 59 |
+
unique: bool = False
|
| 60 |
+
|
| 61 |
+
# For run, compare and eval commands.
|
| 62 |
+
inputs: Sequence[llmfn_inputs_source.LLMFnInputsSource] = dataclasses.field(
|
| 63 |
+
default_factory=list
|
| 64 |
+
)
|
| 65 |
+
sheets_input_names: Sequence[llmfn_inputs_source.LLMFnInputsSource] = dataclasses.field(
|
| 66 |
+
default_factory=list
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
outputs: Sequence[llmfn_outputs.LLMFnOutputsSink] = dataclasses.field(default_factory=list)
|
| 70 |
+
sheets_output_names: Sequence[llmfn_outputs.LLMFnOutputsSink] = dataclasses.field(
|
| 71 |
+
default_factory=list
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# For compile command.
|
| 75 |
+
compile_save_name: str | None = None
|
| 76 |
+
|
| 77 |
+
# For compare command.
|
| 78 |
+
lhs_name_and_fn: tuple[str, llm_function.LLMFunction] | None = None
|
| 79 |
+
rhs_name_and_fn: tuple[str, llm_function.LLMFunction] | None = None
|
| 80 |
+
|
| 81 |
+
# For compare and eval commands.
|
| 82 |
+
compare_fn: Sequence[tuple[str, TextResultCompareFn]] = dataclasses.field(default_factory=list)
|
| 83 |
+
|
| 84 |
+
# For eval command.
|
| 85 |
+
ground_truth: Sequence[str] = dataclasses.field(default_factory=list)
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/post_process_utils_test_helper.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Helper module for post_process_utils_test."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from google.generativeai.notebook import post_process_utils
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def add_length(x: str) -> int:
|
| 22 |
+
return len(x)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@post_process_utils.post_process_add_fn
|
| 26 |
+
def add_length_decorated(x: str) -> int:
|
| 27 |
+
return len(x)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@post_process_utils.post_process_replace_fn
|
| 31 |
+
def to_upper(x: str) -> str:
|
| 32 |
+
return x.upper()
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/run_cmd.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""The run command."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from typing import Sequence
|
| 19 |
+
|
| 20 |
+
from google.generativeai.notebook import command
|
| 21 |
+
from google.generativeai.notebook import command_utils
|
| 22 |
+
from google.generativeai.notebook import input_utils
|
| 23 |
+
from google.generativeai.notebook import ipython_env
|
| 24 |
+
from google.generativeai.notebook import model_registry
|
| 25 |
+
from google.generativeai.notebook import output_utils
|
| 26 |
+
from google.generativeai.notebook import parsed_args_lib
|
| 27 |
+
from google.generativeai.notebook import post_process_utils
|
| 28 |
+
import pandas
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class RunCommand(command.Command):
|
| 32 |
+
"""Implementation of the "run" command."""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
models: model_registry.ModelRegistry,
|
| 37 |
+
env: ipython_env.IPythonEnv | None = None,
|
| 38 |
+
):
|
| 39 |
+
"""Constructor.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
models: ModelRegistry instance.
|
| 43 |
+
env: The IPythonEnv environment.
|
| 44 |
+
"""
|
| 45 |
+
super().__init__()
|
| 46 |
+
self._models = models
|
| 47 |
+
self._ipython_env = env
|
| 48 |
+
|
| 49 |
+
def execute(
|
| 50 |
+
self,
|
| 51 |
+
parsed_args: parsed_args_lib.ParsedArgs,
|
| 52 |
+
cell_content: str,
|
| 53 |
+
post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
|
| 54 |
+
) -> pandas.DataFrame:
|
| 55 |
+
# We expect CmdLineParser to have already read the inputs once to validate
|
| 56 |
+
# that the placeholders in the prompt are present in the inputs, so we can
|
| 57 |
+
# suppress the status messages here.
|
| 58 |
+
inputs = input_utils.join_inputs_sources(parsed_args, suppress_status_msgs=True)
|
| 59 |
+
|
| 60 |
+
llm_fn = command_utils.create_llm_function(
|
| 61 |
+
models=self._models,
|
| 62 |
+
env=self._ipython_env,
|
| 63 |
+
parsed_args=parsed_args,
|
| 64 |
+
cell_content=cell_content,
|
| 65 |
+
post_processing_fns=post_processing_fns,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
results = llm_fn(inputs=inputs)
|
| 69 |
+
output_utils.write_to_outputs(results=results, parsed_args=parsed_args)
|
| 70 |
+
return results.as_pandas_dataframe()
|
| 71 |
+
|
| 72 |
+
def parse_post_processing_tokens(
|
| 73 |
+
self, tokens: Sequence[Sequence[str]]
|
| 74 |
+
) -> Sequence[post_process_utils.ParsedPostProcessExpr]:
|
| 75 |
+
return post_process_utils.resolve_post_processing_tokens(tokens)
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_id.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Module for classes related to identifying a Sheets document."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import re
|
| 19 |
+
from google.generativeai.notebook import sheets_sanitize_url
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _sanitize_key(key: str) -> str:
|
| 23 |
+
if not re.fullmatch("[a-zA-Z0-9_-]+", key):
|
| 24 |
+
raise ValueError('"{}" is not a valid Sheets key'.format(key))
|
| 25 |
+
return key
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SheetsURL:
|
| 29 |
+
"""Class that enforces safety by ensuring that URLs are sanitized."""
|
| 30 |
+
|
| 31 |
+
def __init__(self, url: str):
|
| 32 |
+
self._url: str = sheets_sanitize_url.sanitize_sheets_url(url)
|
| 33 |
+
|
| 34 |
+
def __str__(self) -> str:
|
| 35 |
+
return self._url
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class SheetsKey:
|
| 39 |
+
"""Class that enforces safety by ensuring that keys are sanitized."""
|
| 40 |
+
|
| 41 |
+
def __init__(self, key: str):
|
| 42 |
+
self._key: str = _sanitize_key(key)
|
| 43 |
+
|
| 44 |
+
def __str__(self) -> str:
|
| 45 |
+
return self._key
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class SheetsIdentifier:
|
| 49 |
+
"""Encapsulates a means to identify a Sheets document.
|
| 50 |
+
|
| 51 |
+
The gspread library provides three ways to look up a Sheets document: by name,
|
| 52 |
+
by url and by key. An instance of this class represents exactly one of the
|
| 53 |
+
methods.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
name: str | None = None,
|
| 59 |
+
key: SheetsKey | None = None,
|
| 60 |
+
url: SheetsURL | None = None,
|
| 61 |
+
):
|
| 62 |
+
"""Constructor.
|
| 63 |
+
|
| 64 |
+
Exactly one of the arguments should be provided.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
name: The name of the Sheets document. More-than-one Sheets documents can
|
| 68 |
+
have the same name, so this is the least precise method of identifying
|
| 69 |
+
the document.
|
| 70 |
+
key: The key of the Sheets document
|
| 71 |
+
url: The url to the Sheets document
|
| 72 |
+
|
| 73 |
+
Raises:
|
| 74 |
+
ValueError: If the caller does not specify exactly one of name, url or
|
| 75 |
+
key.
|
| 76 |
+
"""
|
| 77 |
+
self._name = name
|
| 78 |
+
self._key = key
|
| 79 |
+
self._url = url
|
| 80 |
+
|
| 81 |
+
# There should be exactly one.
|
| 82 |
+
num_inputs = int(bool(self._name)) + int(bool(self._key)) + int(bool(self._url))
|
| 83 |
+
if num_inputs != 1:
|
| 84 |
+
raise ValueError("Must set exactly one of name, key or url")
|
| 85 |
+
|
| 86 |
+
def name(self) -> str | None:
|
| 87 |
+
return self._name
|
| 88 |
+
|
| 89 |
+
def key(self) -> SheetsKey | None:
|
| 90 |
+
return self._key
|
| 91 |
+
|
| 92 |
+
def url(self) -> SheetsURL | None:
|
| 93 |
+
return self._url
|
| 94 |
+
|
| 95 |
+
def __str__(self):
|
| 96 |
+
if self._name:
|
| 97 |
+
return "name={}".format(self._name)
|
| 98 |
+
elif self._key:
|
| 99 |
+
return "key={}".format(self._key)
|
| 100 |
+
else:
|
| 101 |
+
return "url={}".format(self._url)
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_sanitize_url.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Utilities for working with URLs."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import re
|
| 19 |
+
from urllib import parse
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _validate_url_part(part: str) -> None:
|
| 23 |
+
if not re.fullmatch("[a-zA-Z0-9_-]*", part):
|
| 24 |
+
raise ValueError('"{}" is outside the restricted character set'.format(part))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _validate_url_query_or_fragment(part: str) -> None:
|
| 28 |
+
for key, values in parse.parse_qs(part).items():
|
| 29 |
+
_validate_url_part(key)
|
| 30 |
+
for value in values:
|
| 31 |
+
_validate_url_part(value)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def sanitize_sheets_url(url: str) -> str:
|
| 35 |
+
"""Sanitize a Sheets URL.
|
| 36 |
+
|
| 37 |
+
Run some saftey checks to check whether `url` is a Sheets URL. This is not a
|
| 38 |
+
general-purpose URL sanitizer. Rather, it makes use of the fact that we know
|
| 39 |
+
the URL has to be for Sheets so we can make a few assumptions about (e.g. the
|
| 40 |
+
domain).
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
url: The url to sanitize.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
The sanitized url.
|
| 47 |
+
|
| 48 |
+
Raises:
|
| 49 |
+
ValueError: If `url` does not match the expected restrictions for a Sheets
|
| 50 |
+
URL.
|
| 51 |
+
"""
|
| 52 |
+
parse_result = parse.urlparse(url)
|
| 53 |
+
if parse_result.scheme != "https":
|
| 54 |
+
raise ValueError(
|
| 55 |
+
'Scheme for Sheets url must be "https", got "{}"'.format(parse_result.scheme)
|
| 56 |
+
)
|
| 57 |
+
if parse_result.netloc not in ("docs.google.com", "sheets.googleapis.com"):
|
| 58 |
+
raise ValueError(
|
| 59 |
+
'Domain for Sheets url must be "docs.google.com", got "{}"'.format(parse_result.netloc)
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# Path component.
|
| 63 |
+
try:
|
| 64 |
+
for fragment in parse_result.path.split("/"):
|
| 65 |
+
_validate_url_part(fragment)
|
| 66 |
+
except ValueError as exc:
|
| 67 |
+
raise ValueError('Invalid path for Sheets url, got "{}"'.format(parse_result.path)) from exc
|
| 68 |
+
|
| 69 |
+
# Params component.
|
| 70 |
+
if parse_result.params:
|
| 71 |
+
raise ValueError('Params component must be empty, got "{}"'.format(parse_result.params))
|
| 72 |
+
|
| 73 |
+
# Query component.
|
| 74 |
+
try:
|
| 75 |
+
_validate_url_query_or_fragment(parse_result.query)
|
| 76 |
+
except ValueError as exc:
|
| 77 |
+
raise ValueError(
|
| 78 |
+
'Invalid query for Sheets url, got "{}"'.format(parse_result.query)
|
| 79 |
+
) from exc
|
| 80 |
+
|
| 81 |
+
# Fragment component.
|
| 82 |
+
try:
|
| 83 |
+
_validate_url_query_or_fragment(parse_result.fragment)
|
| 84 |
+
except ValueError as exc:
|
| 85 |
+
raise ValueError(
|
| 86 |
+
'Invalid fragment for Sheets url, got "{}"'.format(parse_result.fragment)
|
| 87 |
+
) from exc
|
| 88 |
+
|
| 89 |
+
return url
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_utils.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""SheetsInputs."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from typing import Any, Callable, Mapping, Sequence
|
| 19 |
+
from urllib import parse
|
| 20 |
+
from google.generativeai.notebook import gspread_client
|
| 21 |
+
from google.generativeai.notebook import sheets_id
|
| 22 |
+
from google.generativeai.notebook.lib import llmfn_inputs_source
|
| 23 |
+
from google.generativeai.notebook.lib import llmfn_outputs
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _try_sheet_id_as_url(value: str) -> sheets_id.SheetsIdentifier | None:
|
| 27 |
+
"""Try to open a Sheets document with `value` as a URL."""
|
| 28 |
+
try:
|
| 29 |
+
parse_result = parse.urlparse(value)
|
| 30 |
+
except ValueError:
|
| 31 |
+
# If there's a URL parsing error, then it's not a URL.
|
| 32 |
+
return None
|
| 33 |
+
|
| 34 |
+
if parse_result.scheme:
|
| 35 |
+
# If it looks like a URL, try to open the document as a URL but don't fall
|
| 36 |
+
# back to trying as key or name since it's very unlikely that a key or name
|
| 37 |
+
# looks like a URL.
|
| 38 |
+
sid = sheets_id.SheetsIdentifier(url=sheets_id.SheetsURL(value))
|
| 39 |
+
gspread_client.get_client().validate(sid)
|
| 40 |
+
return sid
|
| 41 |
+
|
| 42 |
+
return None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _try_sheet_id_as_key(value: str) -> sheets_id.SheetsIdentifier | None:
|
| 46 |
+
"""Try to open a Sheets document with `value` as a key."""
|
| 47 |
+
try:
|
| 48 |
+
sid = sheets_id.SheetsIdentifier(key=sheets_id.SheetsKey(value))
|
| 49 |
+
except ValueError:
|
| 50 |
+
# `value` is not a well-formed Sheets key.
|
| 51 |
+
return None
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
gspread_client.get_client().validate(sid)
|
| 55 |
+
except gspread_client.SpreadsheetNotFoundError:
|
| 56 |
+
return None
|
| 57 |
+
return sid
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _try_sheet_id_as_name(value: str) -> sheets_id.SheetsIdentifier | None:
|
| 61 |
+
"""Try to open a Sheets document with `value` as a name."""
|
| 62 |
+
sid = sheets_id.SheetsIdentifier(name=value)
|
| 63 |
+
try:
|
| 64 |
+
gspread_client.get_client().validate(sid)
|
| 65 |
+
except gspread_client.SpreadsheetNotFoundError:
|
| 66 |
+
return None
|
| 67 |
+
return sid
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_sheets_id_from_str(value: str) -> sheets_id.SheetsIdentifier:
|
| 71 |
+
if sid := _try_sheet_id_as_url(value):
|
| 72 |
+
return sid
|
| 73 |
+
if sid := _try_sheet_id_as_key(value):
|
| 74 |
+
return sid
|
| 75 |
+
if sid := _try_sheet_id_as_name(value):
|
| 76 |
+
return sid
|
| 77 |
+
raise RuntimeError('No Sheets found with "{}" as URL, key or name'.format(value))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class SheetsInputs(llmfn_inputs_source.LLMFnInputsSource):
|
| 81 |
+
"""Inputs to an LLMFunction from Google Sheets."""
|
| 82 |
+
|
| 83 |
+
def __init__(self, sid: sheets_id.SheetsIdentifier, worksheet_id: int = 0):
|
| 84 |
+
super().__init__()
|
| 85 |
+
self._sid = sid
|
| 86 |
+
self._worksheet_id = worksheet_id
|
| 87 |
+
|
| 88 |
+
def _to_normalized_inputs_impl(
|
| 89 |
+
self,
|
| 90 |
+
) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
|
| 91 |
+
return gspread_client.get_client().get_all_records(
|
| 92 |
+
sid=self._sid, worksheet_id=self._worksheet_id
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SheetsOutputs(llmfn_outputs.LLMFnOutputsSink):
|
| 97 |
+
"""Writes outputs from an LLMFunction to Google Sheets."""
|
| 98 |
+
|
| 99 |
+
def __init__(self, sid: sheets_id.SheetsIdentifier):
|
| 100 |
+
self._sid = sid
|
| 101 |
+
|
| 102 |
+
def write_outputs(self, outputs: llmfn_outputs.LLMFnOutputsBase) -> None:
|
| 103 |
+
# Transpose `outputs` into a list of rows.
|
| 104 |
+
outputs_dict = outputs.as_dict()
|
| 105 |
+
outputs_rows: list[Sequence[Any]] = [list(outputs_dict.keys())]
|
| 106 |
+
outputs_rows.extend([list(x) for x in zip(*outputs_dict.values())])
|
| 107 |
+
|
| 108 |
+
gspread_client.get_client().write_records(
|
| 109 |
+
sid=self._sid,
|
| 110 |
+
rows=outputs_rows,
|
| 111 |
+
)
|
.venv/lib/python3.11/site-packages/google/generativeai/notebook/text_model.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Model that uses the Text service."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from google.api_core import retry
|
| 19 |
+
import google.generativeai as genai
|
| 20 |
+
from google.generativeai.types import generation_types
|
| 21 |
+
from google.generativeai.notebook.lib import model as model_lib
|
| 22 |
+
|
| 23 |
+
_DEFAULT_MODEL = "models/gemini-1.5-flash"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TextModel(model_lib.AbstractModel):
|
| 27 |
+
"""Concrete model that uses the generate_content service."""
|
| 28 |
+
|
| 29 |
+
def _generate_text(
|
| 30 |
+
self,
|
| 31 |
+
prompt: str,
|
| 32 |
+
model: str | None = None,
|
| 33 |
+
temperature: float | None = None,
|
| 34 |
+
candidate_count: int | None = None,
|
| 35 |
+
) -> generation_types.GenerateContentResponse:
|
| 36 |
+
gen_config = {}
|
| 37 |
+
if temperature is not None:
|
| 38 |
+
gen_config["temperature"] = temperature
|
| 39 |
+
if candidate_count is not None:
|
| 40 |
+
gen_config["candidate_count"] = candidate_count
|
| 41 |
+
|
| 42 |
+
model_name = model or _DEFAULT_MODEL
|
| 43 |
+
gen_model = genai.GenerativeModel(model_name=model_name)
|
| 44 |
+
gc = genai.types.generation_types.GenerationConfig(**gen_config)
|
| 45 |
+
return gen_model.generate_content(prompt, generation_config=gc)
|
| 46 |
+
|
| 47 |
+
def call_model(
|
| 48 |
+
self,
|
| 49 |
+
model_input: str,
|
| 50 |
+
model_args: model_lib.ModelArguments | None = None,
|
| 51 |
+
) -> model_lib.ModelResults:
|
| 52 |
+
if model_args is None:
|
| 53 |
+
model_args = model_lib.ModelArguments()
|
| 54 |
+
|
| 55 |
+
# Wrap the generation function here, rather than decorate, so that it
|
| 56 |
+
# applies to any overridden calls too.
|
| 57 |
+
retryable_fn = retry.Retry(retry.if_transient_error)(self._generate_text)
|
| 58 |
+
response = retryable_fn(
|
| 59 |
+
prompt=model_input,
|
| 60 |
+
model=model_args.model,
|
| 61 |
+
temperature=model_args.temperature,
|
| 62 |
+
candidate_count=model_args.candidate_count,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
text_outputs = []
|
| 66 |
+
for c in response.candidates:
|
| 67 |
+
text_outputs.append("".join(p.text for p in c.content.parts))
|
| 68 |
+
|
| 69 |
+
return model_lib.ModelResults(
|
| 70 |
+
model_input=model_input,
|
| 71 |
+
text_results=text_outputs,
|
| 72 |
+
)
|