koichi12 commited on
Commit
1f6807a
·
verified ·
1 Parent(s): 4fbd6f4

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/google/cloud/__pycache__/extended_operations_pb2.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/google/cloud/extended_operations.proto +150 -0
  3. .venv/lib/python3.11/site-packages/google/cloud/extended_operations_pb2.py +47 -0
  4. .venv/lib/python3.11/site-packages/google/cloud/location/__pycache__/locations_pb2.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/google/cloud/location/locations.proto +108 -0
  6. .venv/lib/python3.11/site-packages/google/cloud/location/locations_pb2.py +71 -0
  7. .venv/lib/python3.11/site-packages/google/generativeai/__init__.py +86 -0
  8. .venv/lib/python3.11/site-packages/google/generativeai/answer.py +363 -0
  9. .venv/lib/python3.11/site-packages/google/generativeai/audio_models/__init__.py +0 -0
  10. .venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/__init__.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/_audio_models.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/google/generativeai/audio_models/_audio_models.py +3 -0
  13. .venv/lib/python3.11/site-packages/google/generativeai/caching.py +314 -0
  14. .venv/lib/python3.11/site-packages/google/generativeai/client.py +388 -0
  15. .venv/lib/python3.11/site-packages/google/generativeai/discuss.py +590 -0
  16. .venv/lib/python3.11/site-packages/google/generativeai/embedding.py +312 -0
  17. .venv/lib/python3.11/site-packages/google/generativeai/files.py +116 -0
  18. .venv/lib/python3.11/site-packages/google/generativeai/generative_models.py +875 -0
  19. .venv/lib/python3.11/site-packages/google/generativeai/models.py +467 -0
  20. .venv/lib/python3.11/site-packages/google/generativeai/notebook/__init__.py +32 -0
  21. .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/command.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/compare_cmd.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/eval_cmd.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/flag_def.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/input_utils.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/magics_engine.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/parsed_args_lib.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/run_cmd.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_id.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_utils.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/google/generativeai/notebook/cmd_line_parser.py +557 -0
  32. .venv/lib/python3.11/site-packages/google/generativeai/notebook/command.py +45 -0
  33. .venv/lib/python3.11/site-packages/google/generativeai/notebook/compare_cmd.py +71 -0
  34. .venv/lib/python3.11/site-packages/google/generativeai/notebook/compile_cmd.py +66 -0
  35. .venv/lib/python3.11/site-packages/google/generativeai/notebook/eval_cmd.py +75 -0
  36. .venv/lib/python3.11/site-packages/google/generativeai/notebook/flag_def.py +503 -0
  37. .venv/lib/python3.11/site-packages/google/generativeai/notebook/gspread_client.py +232 -0
  38. .venv/lib/python3.11/site-packages/google/generativeai/notebook/html_utils.py +46 -0
  39. .venv/lib/python3.11/site-packages/google/generativeai/notebook/input_utils.py +73 -0
  40. .venv/lib/python3.11/site-packages/google/generativeai/notebook/ipython_env_impl.py +30 -0
  41. .venv/lib/python3.11/site-packages/google/generativeai/notebook/magics.py +150 -0
  42. .venv/lib/python3.11/site-packages/google/generativeai/notebook/magics_engine.py +124 -0
  43. .venv/lib/python3.11/site-packages/google/generativeai/notebook/model_registry.py +55 -0
  44. .venv/lib/python3.11/site-packages/google/generativeai/notebook/parsed_args_lib.py +85 -0
  45. .venv/lib/python3.11/site-packages/google/generativeai/notebook/post_process_utils_test_helper.py +32 -0
  46. .venv/lib/python3.11/site-packages/google/generativeai/notebook/run_cmd.py +75 -0
  47. .venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_id.py +101 -0
  48. .venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_sanitize_url.py +89 -0
  49. .venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_utils.py +111 -0
  50. .venv/lib/python3.11/site-packages/google/generativeai/notebook/text_model.py +72 -0
.venv/lib/python3.11/site-packages/google/cloud/__pycache__/extended_operations_pb2.cpython-311.pyc ADDED
Binary file (2.24 kB). View file
 
.venv/lib/python3.11/site-packages/google/cloud/extended_operations.proto ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright 2024 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ // This file contains custom annotations that are used by GAPIC generators to
16
+ // handle Long Running Operation methods (LRO) that are NOT compliant with
17
+ // https://google.aip.dev/151. These annotations are public for technical
18
+ // reasons only. Please DO NOT USE them in your protos.
19
+ syntax = "proto3";
20
+
21
+ package google.cloud;
22
+
23
+ import "google/protobuf/descriptor.proto";
24
+
25
+ option go_package = "google.golang.org/genproto/googleapis/cloud/extendedops;extendedops";
26
+ option java_multiple_files = true;
27
+ option java_outer_classname = "ExtendedOperationsProto";
28
+ option java_package = "com.google.cloud";
29
+ option objc_class_prefix = "GAPI";
30
+
31
+ // FieldOptions to match corresponding fields in the initial request,
32
+ // polling request and operation response messages.
33
+ //
34
+ // Example:
35
+ //
36
+ // In an API-specific operation message:
37
+ //
38
+ // message MyOperation {
39
+ // string http_error_message = 1 [(operation_field) = ERROR_MESSAGE];
40
+ // int32 http_error_status_code = 2 [(operation_field) = ERROR_CODE];
41
+ // string id = 3 [(operation_field) = NAME];
42
+ // Status status = 4 [(operation_field) = STATUS];
43
+ // }
44
+ //
45
+ // In a polling request message (the one which is used to poll for an LRO
46
+ // status):
47
+ //
48
+ // message MyPollingRequest {
49
+ // string operation = 1 [(operation_response_field) = "id"];
50
+ // string project = 2;
51
+ // string region = 3;
52
+ // }
53
+ //
54
+ // In an initial request message (the one which starts an LRO):
55
+ //
56
+ // message MyInitialRequest {
57
+ // string my_project = 2 [(operation_request_field) = "project"];
58
+ // string my_region = 3 [(operation_request_field) = "region"];
59
+ // }
60
+ //
61
+ extend google.protobuf.FieldOptions {
62
+ // A field annotation that maps fields in an API-specific Operation object to
63
+ // their standard counterparts in google.longrunning.Operation. See
64
+ // OperationResponseMapping enum definition.
65
+ OperationResponseMapping operation_field = 1149;
66
+
67
+ // A field annotation that maps fields in the initial request message
68
+ // (the one which started the LRO) to their counterparts in the polling
69
+ // request message. For non-standard LRO, the polling response may be missing
70
+ // some of the information needed to make a subsequent polling request. The
71
+ // missing information (for example, project or region ID) is contained in the
72
+ // fields of the initial request message that this annotation must be applied
73
+ // to. The string value of the annotation corresponds to the name of the
74
+ // counterpart field in the polling request message that the annotated field's
75
+ // value will be copied to.
76
+ string operation_request_field = 1150;
77
+
78
+ // A field annotation that maps fields in the polling request message to their
79
+ // counterparts in the initial and/or polling response message. The initial
80
+ // and the polling methods return an API-specific Operation object. Some of
81
+ // the fields from that response object must be reused in the subsequent
82
+ // request (like operation name/ID) to fully identify the polled operation.
83
+ // This annotation must be applied to the fields in the polling request
84
+ // message, the string value of the annotation must correspond to the name of
85
+ // the counterpart field in the Operation response object whose value will be
86
+ // copied to the annotated field.
87
+ string operation_response_field = 1151;
88
+ }
89
+
90
+ // MethodOptions to identify the actual service and method used for operation
91
+ // status polling.
92
+ //
93
+ // Example:
94
+ //
95
+ // In a method, which starts an LRO:
96
+ //
97
+ // service MyService {
98
+ // rpc Foo(MyInitialRequest) returns (MyOperation) {
99
+ // option (operation_service) = "MyPollingService";
100
+ // }
101
+ // }
102
+ //
103
+ // In a polling method:
104
+ //
105
+ // service MyPollingService {
106
+ // rpc Get(MyPollingRequest) returns (MyOperation) {
107
+ // option (operation_polling_method) = true;
108
+ // }
109
+ // }
110
+ extend google.protobuf.MethodOptions {
111
+ // A method annotation that maps an LRO method (the one which starts an LRO)
112
+ // to the service, which will be used to poll for the operation status. The
113
+ // annotation must be applied to the method which starts an LRO, the string
114
+ // value of the annotation must correspond to the name of the service used to
115
+ // poll for the operation status.
116
+ string operation_service = 1249;
117
+
118
+ // A method annotation that marks methods that can be used for polling
119
+ // operation status (e.g. the MyPollingService.Get(MyPollingRequest) method).
120
+ bool operation_polling_method = 1250;
121
+ }
122
+
123
+ // An enum to be used to mark the essential (for polling) fields in an
124
+ // API-specific Operation object. A custom Operation object may contain many
125
+ // different fields, but only few of them are essential to conduct a successful
126
+ // polling process.
127
+ enum OperationResponseMapping {
128
+ // Do not use.
129
+ UNDEFINED = 0;
130
+
131
+ // A field in an API-specific (custom) Operation object which carries the same
132
+ // meaning as google.longrunning.Operation.name.
133
+ NAME = 1;
134
+
135
+ // A field in an API-specific (custom) Operation object which carries the same
136
+ // meaning as google.longrunning.Operation.done. If the annotated field is of
137
+ // an enum type, `annotated_field_name == EnumType.DONE` semantics should be
138
+ // equivalent to `Operation.done == true`. If the annotated field is of type
139
+ // boolean, then it should follow the same semantics as Operation.done.
140
+ // Otherwise, a non-empty value should be treated as `Operation.done == true`.
141
+ STATUS = 2;
142
+
143
+ // A field in an API-specific (custom) Operation object which carries the same
144
+ // meaning as google.longrunning.Operation.error.code.
145
+ ERROR_CODE = 3;
146
+
147
+ // A field in an API-specific (custom) Operation object which carries the same
148
+ // meaning as google.longrunning.Operation.error.message.
149
+ ERROR_MESSAGE = 4;
150
+ }
.venv/lib/python3.11/site-packages/google/cloud/extended_operations_pb2.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2024 Google LLC
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
18
+ # source: google/cloud/extended_operations.proto
19
+ """Generated protocol buffer code."""
20
+ from google.protobuf import descriptor as _descriptor
21
+ from google.protobuf import descriptor_pool as _descriptor_pool
22
+ from google.protobuf import symbol_database as _symbol_database
23
+ from google.protobuf.internal import builder as _builder
24
+
25
+ # @@protoc_insertion_point(imports)
26
+
27
+ _sym_db = _symbol_database.Default()
28
+
29
+
30
+ from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2
31
+
32
+
33
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
34
+ b"\n&google/cloud/extended_operations.proto\x12\x0cgoogle.cloud\x1a google/protobuf/descriptor.proto*b\n\x18OperationResponseMapping\x12\r\n\tUNDEFINED\x10\x00\x12\x08\n\x04NAME\x10\x01\x12\n\n\x06STATUS\x10\x02\x12\x0e\n\nERROR_CODE\x10\x03\x12\x11\n\rERROR_MESSAGE\x10\x04:_\n\x0foperation_field\x12\x1d.google.protobuf.FieldOptions\x18\xfd\x08 \x01(\x0e\x32&.google.cloud.OperationResponseMapping:?\n\x17operation_request_field\x12\x1d.google.protobuf.FieldOptions\x18\xfe\x08 \x01(\t:@\n\x18operation_response_field\x12\x1d.google.protobuf.FieldOptions\x18\xff\x08 \x01(\t::\n\x11operation_service\x12\x1e.google.protobuf.MethodOptions\x18\xe1\t \x01(\t:A\n\x18operation_polling_method\x12\x1e.google.protobuf.MethodOptions\x18\xe2\t \x01(\x08\x42y\n\x10\x63om.google.cloudB\x17\x45xtendedOperationsProtoP\x01ZCgoogle.golang.org/genproto/googleapis/cloud/extendedops;extendedops\xa2\x02\x04GAPIb\x06proto3"
35
+ )
36
+
37
+ _globals = globals()
38
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
39
+ _builder.BuildTopDescriptorsAndMessages(
40
+ DESCRIPTOR, "google.cloud.extended_operations_pb2", _globals
41
+ )
42
+ if _descriptor._USE_C_DESCRIPTORS == False:
43
+ DESCRIPTOR._options = None
44
+ DESCRIPTOR._serialized_options = b"\n\020com.google.cloudB\027ExtendedOperationsProtoP\001ZCgoogle.golang.org/genproto/googleapis/cloud/extendedops;extendedops\242\002\004GAPI"
45
+ _globals["_OPERATIONRESPONSEMAPPING"]._serialized_start = 90
46
+ _globals["_OPERATIONRESPONSEMAPPING"]._serialized_end = 188
47
+ # @@protoc_insertion_point(module_scope)
.venv/lib/python3.11/site-packages/google/cloud/location/__pycache__/locations_pb2.cpython-311.pyc ADDED
Binary file (3.99 kB). View file
 
.venv/lib/python3.11/site-packages/google/cloud/location/locations.proto ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright 2024 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ syntax = "proto3";
16
+
17
+ package google.cloud.location;
18
+
19
+ import "google/api/annotations.proto";
20
+ import "google/protobuf/any.proto";
21
+ import "google/api/client.proto";
22
+
23
+ option cc_enable_arenas = true;
24
+ option go_package = "google.golang.org/genproto/googleapis/cloud/location;location";
25
+ option java_multiple_files = true;
26
+ option java_outer_classname = "LocationsProto";
27
+ option java_package = "com.google.cloud.location";
28
+
29
+ // An abstract interface that provides location-related information for
30
+ // a service. Service-specific metadata is provided through the
31
+ // [Location.metadata][google.cloud.location.Location.metadata] field.
32
+ service Locations {
33
+ option (google.api.default_host) = "cloud.googleapis.com";
34
+ option (google.api.oauth_scopes) = "https://www.googleapis.com/auth/cloud-platform";
35
+
36
+ // Lists information about the supported locations for this service.
37
+ rpc ListLocations(ListLocationsRequest) returns (ListLocationsResponse) {
38
+ option (google.api.http) = {
39
+ get: "/v1/{name=locations}"
40
+ additional_bindings {
41
+ get: "/v1/{name=projects/*}/locations"
42
+ }
43
+ };
44
+ }
45
+
46
+ // Gets information about a location.
47
+ rpc GetLocation(GetLocationRequest) returns (Location) {
48
+ option (google.api.http) = {
49
+ get: "/v1/{name=locations/*}"
50
+ additional_bindings {
51
+ get: "/v1/{name=projects/*/locations/*}"
52
+ }
53
+ };
54
+ }
55
+ }
56
+
57
+ // The request message for [Locations.ListLocations][google.cloud.location.Locations.ListLocations].
58
+ message ListLocationsRequest {
59
+ // The resource that owns the locations collection, if applicable.
60
+ string name = 1;
61
+
62
+ // The standard list filter.
63
+ string filter = 2;
64
+
65
+ // The standard list page size.
66
+ int32 page_size = 3;
67
+
68
+ // The standard list page token.
69
+ string page_token = 4;
70
+ }
71
+
72
+ // The response message for [Locations.ListLocations][google.cloud.location.Locations.ListLocations].
73
+ message ListLocationsResponse {
74
+ // A list of locations that matches the specified filter in the request.
75
+ repeated Location locations = 1;
76
+
77
+ // The standard List next-page token.
78
+ string next_page_token = 2;
79
+ }
80
+
81
+ // The request message for [Locations.GetLocation][google.cloud.location.Locations.GetLocation].
82
+ message GetLocationRequest {
83
+ // Resource name for the location.
84
+ string name = 1;
85
+ }
86
+
87
+ // A resource that represents Google Cloud Platform location.
88
+ message Location {
89
+ // Resource name for the location, which may vary between implementations.
90
+ // For example: `"projects/example-project/locations/us-east1"`
91
+ string name = 1;
92
+
93
+ // The canonical id for this location. For example: `"us-east1"`.
94
+ string location_id = 4;
95
+
96
+ // The friendly name for this location, typically a nearby city name.
97
+ // For example, "Tokyo".
98
+ string display_name = 5;
99
+
100
+ // Cross-service attributes for the location. For example
101
+ //
102
+ // {"cloud.googleapis.com/region": "us-east1"}
103
+ map<string, string> labels = 2;
104
+
105
+ // Service-specific metadata. For example the available capacity at the given
106
+ // location.
107
+ google.protobuf.Any metadata = 3;
108
+ }
.venv/lib/python3.11/site-packages/google/cloud/location/locations_pb2.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2024 Google LLC
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
18
+ # source: google/cloud/location/locations.proto
19
+ """Generated protocol buffer code."""
20
+ from google.protobuf import descriptor as _descriptor
21
+ from google.protobuf import descriptor_pool as _descriptor_pool
22
+ from google.protobuf import symbol_database as _symbol_database
23
+ from google.protobuf.internal import builder as _builder
24
+
25
+ # @@protoc_insertion_point(imports)
26
+
27
+ _sym_db = _symbol_database.Default()
28
+
29
+
30
+ from google.api import annotations_pb2 as google_dot_api_dot_annotations__pb2
31
+ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
32
+ from google.api import client_pb2 as google_dot_api_dot_client__pb2
33
+
34
+
35
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
36
+ b'\n%google/cloud/location/locations.proto\x12\x15google.cloud.location\x1a\x1cgoogle/api/annotations.proto\x1a\x19google/protobuf/any.proto\x1a\x17google/api/client.proto"[\n\x14ListLocationsRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06\x66ilter\x18\x02 \x01(\t\x12\x11\n\tpage_size\x18\x03 \x01(\x05\x12\x12\n\npage_token\x18\x04 \x01(\t"d\n\x15ListLocationsResponse\x12\x32\n\tlocations\x18\x01 \x03(\x0b\x32\x1f.google.cloud.location.Location\x12\x17\n\x0fnext_page_token\x18\x02 \x01(\t""\n\x12GetLocationRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"\xd7\x01\n\x08Location\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0blocation_id\x18\x04 \x01(\t\x12\x14\n\x0c\x64isplay_name\x18\x05 \x01(\t\x12;\n\x06labels\x18\x02 \x03(\x0b\x32+.google.cloud.location.Location.LabelsEntry\x12&\n\x08metadata\x18\x03 \x01(\x0b\x32\x14.google.protobuf.Any\x1a-\n\x0bLabelsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x32\xa4\x03\n\tLocations\x12\xab\x01\n\rListLocations\x12+.google.cloud.location.ListLocationsRequest\x1a,.google.cloud.location.ListLocationsResponse"?\x82\xd3\xe4\x93\x02\x39\x12\x14/v1/{name=locations}Z!\x12\x1f/v1/{name=projects/*}/locations\x12\x9e\x01\n\x0bGetLocation\x12).google.cloud.location.GetLocationRequest\x1a\x1f.google.cloud.location.Location"C\x82\xd3\xe4\x93\x02=\x12\x16/v1/{name=locations/*}Z#\x12!/v1/{name=projects/*/locations/*}\x1aH\xca\x41\x14\x63loud.googleapis.com\xd2\x41.https://www.googleapis.com/auth/cloud-platformBo\n\x19\x63om.google.cloud.locationB\x0eLocationsProtoP\x01Z=google.golang.org/genproto/googleapis/cloud/location;location\xf8\x01\x01\x62\x06proto3'
37
+ )
38
+
39
+ _globals = globals()
40
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
41
+ _builder.BuildTopDescriptorsAndMessages(
42
+ DESCRIPTOR, "google.cloud.location.locations_pb2", _globals
43
+ )
44
+ if _descriptor._USE_C_DESCRIPTORS == False:
45
+ DESCRIPTOR._options = None
46
+ DESCRIPTOR._serialized_options = b"\n\031com.google.cloud.locationB\016LocationsProtoP\001Z=google.golang.org/genproto/googleapis/cloud/location;location\370\001\001"
47
+ _LOCATION_LABELSENTRY._options = None
48
+ _LOCATION_LABELSENTRY._serialized_options = b"8\001"
49
+ _LOCATIONS._options = None
50
+ _LOCATIONS._serialized_options = b"\312A\024cloud.googleapis.com\322A.https://www.googleapis.com/auth/cloud-platform"
51
+ _LOCATIONS.methods_by_name["ListLocations"]._options = None
52
+ _LOCATIONS.methods_by_name[
53
+ "ListLocations"
54
+ ]._serialized_options = b"\202\323\344\223\0029\022\024/v1/{name=locations}Z!\022\037/v1/{name=projects/*}/locations"
55
+ _LOCATIONS.methods_by_name["GetLocation"]._options = None
56
+ _LOCATIONS.methods_by_name[
57
+ "GetLocation"
58
+ ]._serialized_options = b"\202\323\344\223\002=\022\026/v1/{name=locations/*}Z#\022!/v1/{name=projects/*/locations/*}"
59
+ _globals["_LISTLOCATIONSREQUEST"]._serialized_start = 146
60
+ _globals["_LISTLOCATIONSREQUEST"]._serialized_end = 237
61
+ _globals["_LISTLOCATIONSRESPONSE"]._serialized_start = 239
62
+ _globals["_LISTLOCATIONSRESPONSE"]._serialized_end = 339
63
+ _globals["_GETLOCATIONREQUEST"]._serialized_start = 341
64
+ _globals["_GETLOCATIONREQUEST"]._serialized_end = 375
65
+ _globals["_LOCATION"]._serialized_start = 378
66
+ _globals["_LOCATION"]._serialized_end = 593
67
+ _globals["_LOCATION_LABELSENTRY"]._serialized_start = 548
68
+ _globals["_LOCATION_LABELSENTRY"]._serialized_end = 593
69
+ _globals["_LOCATIONS"]._serialized_start = 596
70
+ _globals["_LOCATIONS"]._serialized_end = 1016
71
+ # @@protoc_insertion_point(module_scope)
.venv/lib/python3.11/site-packages/google/generativeai/__init__.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Google AI Python SDK
16
+
17
+ ## Setup
18
+
19
+ ```posix-terminal
20
+ pip install google-generativeai
21
+ ```
22
+
23
+ ## GenerativeModel
24
+
25
+ Use `genai.GenerativeModel` to access the API:
26
+
27
+ ```
28
+ import google.generativeai as genai
29
+ import os
30
+
31
+ genai.configure(api_key=os.environ['API_KEY'])
32
+
33
+ model = genai.GenerativeModel(model_name='gemini-1.5-flash')
34
+ response = model.generate_content('Teach me about how an LLM works')
35
+
36
+ print(response.text)
37
+ ```
38
+
39
+ See the [python quickstart](https://ai.google.dev/tutorials/python_quickstart) for more details.
40
+ """
41
+ from __future__ import annotations
42
+
43
+ from google.generativeai import version
44
+
45
+ from google.generativeai import caching
46
+ from google.generativeai import protos
47
+ from google.generativeai import types
48
+
49
+ from google.generativeai.client import configure
50
+
51
+ from google.generativeai.embedding import embed_content
52
+ from google.generativeai.embedding import embed_content_async
53
+
54
+ from google.generativeai.files import upload_file
55
+ from google.generativeai.files import get_file
56
+ from google.generativeai.files import list_files
57
+ from google.generativeai.files import delete_file
58
+
59
+ from google.generativeai.generative_models import GenerativeModel
60
+ from google.generativeai.generative_models import ChatSession
61
+
62
+ from google.generativeai.models import list_models
63
+ from google.generativeai.models import list_tuned_models
64
+
65
+ from google.generativeai.models import get_model
66
+ from google.generativeai.models import get_base_model
67
+ from google.generativeai.models import get_tuned_model
68
+
69
+ from google.generativeai.models import create_tuned_model
70
+ from google.generativeai.models import update_tuned_model
71
+ from google.generativeai.models import delete_tuned_model
72
+
73
+ from google.generativeai.operations import list_operations
74
+ from google.generativeai.operations import get_operation
75
+
76
+ from google.generativeai.types import GenerationConfig
77
+
78
+ __version__ = version.__version__
79
+
80
+ del embedding
81
+ del files
82
+ del generative_models
83
+ del models
84
+ del client
85
+ del operations
86
+ del version
.venv/lib/python3.11/site-packages/google/generativeai/answer.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ import dataclasses
18
+ from collections.abc import Iterable
19
+ import itertools
20
+ from typing import Any, Iterable, Union, Mapping, Optional
21
+ from typing_extensions import TypedDict
22
+
23
+ import google.ai.generativelanguage as glm
24
+ from google.generativeai import protos
25
+
26
+ from google.generativeai.client import (
27
+ get_default_generative_client,
28
+ get_default_generative_async_client,
29
+ )
30
+ from google.generativeai.types import model_types
31
+ from google.generativeai.types import helper_types
32
+ from google.generativeai.types import safety_types
33
+ from google.generativeai.types import content_types
34
+ from google.generativeai.types import retriever_types
35
+ from google.generativeai.types.retriever_types import MetadataFilter
36
+
37
+ DEFAULT_ANSWER_MODEL = "models/aqa"
38
+
39
+ AnswerStyle = protos.GenerateAnswerRequest.AnswerStyle
40
+
41
+ AnswerStyleOptions = Union[int, str, AnswerStyle]
42
+
43
+ _ANSWER_STYLES: dict[AnswerStyleOptions, AnswerStyle] = {
44
+ AnswerStyle.ANSWER_STYLE_UNSPECIFIED: AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
45
+ 0: AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
46
+ "answer_style_unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
47
+ "unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
48
+ AnswerStyle.ABSTRACTIVE: AnswerStyle.ABSTRACTIVE,
49
+ 1: AnswerStyle.ABSTRACTIVE,
50
+ "answer_style_abstractive": AnswerStyle.ABSTRACTIVE,
51
+ "abstractive": AnswerStyle.ABSTRACTIVE,
52
+ AnswerStyle.EXTRACTIVE: AnswerStyle.EXTRACTIVE,
53
+ 2: AnswerStyle.EXTRACTIVE,
54
+ "answer_style_extractive": AnswerStyle.EXTRACTIVE,
55
+ "extractive": AnswerStyle.EXTRACTIVE,
56
+ AnswerStyle.VERBOSE: AnswerStyle.VERBOSE,
57
+ 3: AnswerStyle.VERBOSE,
58
+ "answer_style_verbose": AnswerStyle.VERBOSE,
59
+ "verbose": AnswerStyle.VERBOSE,
60
+ }
61
+
62
+
63
def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle:
    """Normalize `x` (enum member, int value, or string name) to an `AnswerStyle`."""
    key = x.lower() if isinstance(x, str) else x
    return _ANSWER_STYLES[key]
67
+
68
+
69
# A single grounding passage: a fully-formed proto, an `(id, content)` pair,
# or bare content (an id is auto-assigned from its position in the iterable).
# BUG FIX: a stray trailing comma previously made this a 1-tuple containing a
# Union rather than a Union type alias; it only worked because subscripting
# `Iterable[...]` with a 1-tuple happens to unwrap it.
GroundingPassageOptions = Union[
    protos.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType
]

# Anything `_make_grounding_passages` accepts: a ready-made proto, an iterable
# of passages, or a mapping of id -> content.
GroundingPassagesOptions = Union[
    protos.GroundingPassages,
    Iterable[GroundingPassageOptions],
    Mapping[str, content_types.ContentType],
]
80
+
81
+
82
def _make_grounding_passages(source: GroundingPassagesOptions) -> protos.GroundingPassages:
    """Convert `source` into a `protos.GroundingPassages`.

    A `GroundingPassages` holds a list of `protos.GroundingPassage` objects,
    each pairing a `protos.Content` with a string `id`. Mappings are treated
    as id -> content; bare content items get their position as the id.

    Args:
        source: A `GroundingPassagesOptions` value to convert.

    Returns:
        A `protos.GroundingPassages` suitable for `protos.GenerateAnswerRequest`.

    Raises:
        TypeError: If `source` is not iterable (and not already a proto).
    """
    if isinstance(source, protos.GroundingPassages):
        return source

    if not isinstance(source, Iterable):
        raise TypeError(
            f"Invalid input: The 'source' argument must be an instance of 'GroundingPassagesOptions'. Received a '{type(source).__name__}' object instead."
        )

    # A mapping supplies explicit ids; iterate its (id, content) pairs.
    if isinstance(source, Mapping):
        source = source.items()

    passages = []
    for position, item in enumerate(source):
        if isinstance(item, protos.GroundingPassage):
            passages.append(item)
        elif isinstance(item, tuple):
            passage_id, content = item  # tuple must have exactly 2 items.
            passages.append({"id": passage_id, "content": content_types.to_content(content)})
        else:
            # No id given: fall back to the item's position in the sequence.
            passages.append({"id": str(position), "content": content_types.to_content(item)})

    return protos.GroundingPassages(passages=passages)
115
+
116
+
117
# Anything that names a retriever source: a bare resource-name string, or a
# Corpus/Document wrapper (either this SDK's types or the raw protos).
SourceNameType = Union[
    str, retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document
]


class SemanticRetrieverConfigDict(TypedDict):
    """Dict form of `protos.SemanticRetrieverConfig` accepted by `_make_semantic_retriever_config`."""

    source: SourceNameType
    query: content_types.ContentsType
    metadata_filter: Optional[Iterable[MetadataFilter]]
    max_chunks_count: Optional[int]
    minimum_relevance_score: Optional[float]


# Anything convertible to a `protos.SemanticRetrieverConfig`: a source name,
# the dict form above, or a ready-made proto.
SemanticRetrieverConfigOptions = Union[
    SourceNameType,
    SemanticRetrieverConfigDict,
    protos.SemanticRetrieverConfig,
]
135
+
136
+
137
def _maybe_get_source_name(source) -> str | None:
    """Return the resource name of `source`, or None if it carries no name."""
    if isinstance(source, str):
        return source

    named_source_types = (
        retriever_types.Corpus,
        protos.Corpus,
        retriever_types.Document,
        protos.Document,
    )
    if isinstance(source, named_source_types):
        return source.name

    return None
146
+
147
+
148
def _make_semantic_retriever_config(
    source: SemanticRetrieverConfigOptions,
    query: content_types.ContentsType,
) -> protos.SemanticRetrieverConfig:
    """Build a `protos.SemanticRetrieverConfig` from `source`.

    Args:
        source: A source name (string, Corpus, or Document), a
            `SemanticRetrieverConfigDict`, or a ready-made
            `protos.SemanticRetrieverConfig` (returned unchanged).
        query: Fallback query `Content` used when `source` does not carry one
            (callers pass the last `Content` of the conversation).

    Returns:
        A `protos.SemanticRetrieverConfig` for `protos.GenerateAnswerRequest`.

    Raises:
        TypeError: If `source` cannot be interpreted as any accepted form.
    """
    if isinstance(source, protos.SemanticRetrieverConfig):
        return source

    name = _maybe_get_source_name(source)
    if name is not None:
        source = {"source": name}
    elif isinstance(source, dict):
        source["source"] = _maybe_get_source_name(source["source"])
    else:
        raise TypeError(
            f"Invalid input: Failed to create a 'protos.SemanticRetrieverConfig' from the provided source. "
            f"Received type: {type(source).__name__}, "
            f"Received value: {source}"
        )

    # BUG FIX: use `.get` -- the dict built above has no "query" key, and
    # user-supplied dicts may legitimately omit it too (TypedDict keys are not
    # enforced at runtime); indexing `source["query"]` raised KeyError here.
    if source.get("query") is None:
        source["query"] = query
    elif isinstance(source["query"], str):
        source["query"] = content_types.to_content(source["query"])

    return protos.SemanticRetrieverConfig(source)
173
+
174
+
175
def _make_generate_answer_request(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions | None = None,
    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
    answer_style: AnswerStyle | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
) -> protos.GenerateAnswerRequest:
    """Assemble a `protos.GenerateAnswerRequest` from user-friendly arguments.

    Args:
        model: Name of the model used to generate the grounded response.
        contents: Conversation content; for multi-turn queries the last
            `Content` in the list holds the question.
        inline_passages: Grounding passages to send inline with the request.
            Mutually exclusive with `semantic_retriever`; exactly one of the
            two must be provided.
        semantic_retriever: A Corpus, Document, or
            `protos.SemanticRetrieverConfig` to ground against. Mutually
            exclusive with `inline_passages`.
        answer_style: Style for grounded answers.
        safety_settings: Safety settings for generated output.
        temperature: The temperature for randomness in the output.

    Returns:
        The populated `protos.GenerateAnswerRequest`.

    Raises:
        ValueError: If both grounding sources are supplied.
        TypeError: If neither grounding source is supplied.
    """
    model = model_types.make_model_name(model)
    contents = content_types.to_contents(contents)

    if safety_settings:
        safety_settings = safety_types.normalize_safety_settings(safety_settings)

    has_inline = inline_passages is not None
    has_retriever = semantic_retriever is not None

    # Exactly one grounding source is required.
    if has_inline and has_retriever:
        raise ValueError(
            f"Invalid configuration: Please set either 'inline_passages' or 'semantic_retriever_config', but not both. "
            f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}."
        )
    if not has_inline and not has_retriever:
        raise TypeError(
            f"Invalid configuration: Either 'inline_passages' or 'semantic_retriever_config' must be provided, but currently both are 'None'. "
            f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}."
        )

    if has_inline:
        inline_passages = _make_grounding_passages(inline_passages)
    else:
        # The retriever query defaults to the final (question) Content.
        semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1])

    if answer_style:
        answer_style = to_answer_style(answer_style)

    return protos.GenerateAnswerRequest(
        model=model,
        contents=contents,
        inline_passages=inline_passages,
        semantic_retriever=semantic_retriever,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )
239
+
240
+
241
def generate_answer(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions | None = None,
    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
    answer_style: AnswerStyle | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
    client: glm.GenerativeServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
):
    """Calls the GenerateAnswer API and returns a `types.Answer` with the response.

    Grounding can be supplied inline as literal text chunks:

    >>> from google.generativeai import answer
    >>> answer.generate_answer(
    ...     content=question,
    ...     inline_passages=splitter.split(document)
    ... )

    Or as a reference to a retriever Document or Corpus:

    >>> from google.generativeai import answer
    >>> from google.generativeai import retriever
    >>> my_corpus = retriever.get_corpus('my_corpus')
    >>> genai.generate_answer(
    ...     content=question,
    ...     semantic_retriever=my_corpus
    ... )

    Args:
        model: Which model to call, as a string or a `types.Model`.
        contents: The question to be answered by the model, grounded in the
            provided source.
        inline_passages: Grounding passages to send inline with the request.
            Exclusive with `semantic_retriever`; one must be set, but not both.
        semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig`
            to use for grounding. Exclusive with `inline_passages`; one must be
            set, but not both.
        answer_style: Style in which the grounded answer should be returned.
        safety_settings: Safety settings for generated output. Defaults to None.
        temperature: Controls the randomness of the output.
        client: If you're not relying on a default client, you pass a
            `glm.GenerativeServiceClient` instead.
        request_options: Options for the request.

    Returns:
        A `types.Answer` containing the model's text answer response.
    """
    if client is None:
        client = get_default_generative_client()

    if request_options is None:
        request_options = {}

    request = _make_generate_answer_request(
        model=model,
        contents=contents,
        inline_passages=inline_passages,
        semantic_retriever=semantic_retriever,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )

    return client.generate_answer(request, **request_options)
311
+
312
+
313
async def generate_answer_async(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions | None = None,
    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
    answer_style: AnswerStyle | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
    client: glm.GenerativeServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
):
    """Async variant of `generate_answer`; returns a `types.Answer`.

    Args:
        model: Which model to call, as a string or a `types.Model`.
        contents: The question to be answered by the model, grounded in the
            provided source.
        inline_passages: Grounding passages to send inline with the request.
            Exclusive with `semantic_retriever`; one must be set, but not both.
        semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig`
            to use for grounding. Exclusive with `inline_passages`; one must be
            set, but not both.
        answer_style: Style in which the grounded answer should be returned.
        safety_settings: Safety settings for generated output. Defaults to None.
        temperature: Controls the randomness of the output.
        client: If you're not relying on a default client, you pass a
            `glm.GenerativeServiceClient` instead.
        request_options: Options for the request.

    Returns:
        A `types.Answer` containing the model's text answer response.
    """
    if client is None:
        client = get_default_generative_async_client()

    if request_options is None:
        request_options = {}

    request = _make_generate_answer_request(
        model=model,
        contents=contents,
        inline_passages=inline_passages,
        semantic_retriever=semantic_retriever,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )

    return await client.generate_answer(request, **request_options)
.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (205 Bytes). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/audio_models/__pycache__/_audio_models.cpython-311.pyc ADDED
Binary file (629 Bytes). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/audio_models/_audio_models.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
class SpeechGenerationModel:
    """Minimal handle for a speech-generation model, identified by its name."""

    def __init__(self, model_name):
        # Resource name of the model (e.g. "models/...").
        self.model_name = model_name

    def __repr__(self):
        # Added for debuggability; the class previously had no repr/docs.
        return f"{type(self).__name__}(model_name={self.model_name!r})"
.venv/lib/python3.11/site-packages/google/generativeai/caching.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2024 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ import datetime
18
+ import textwrap
19
+ from typing import Iterable, Optional
20
+
21
+ from google.generativeai import protos
22
+ from google.generativeai.types import caching_types
23
+ from google.generativeai.types import content_types
24
+ from google.generativeai.client import get_default_cache_client
25
+
26
+ from google.protobuf import field_mask_pb2
27
+
28
+ _USER_ROLE = "user"
29
+ _MODEL_ROLE = "model"
30
+
31
+
32
class CachedContent:
    """Cached content resource.

    Wraps a `protos.CachedContent` (held in `self._proto`) and exposes its
    fields as read-only properties, plus create/get/list/delete/update
    operations through the default cache client.
    """

    def __init__(self, name):
        """Fetches a `CachedContent` resource.

        Identical to `CachedContent.get`.

        Args:
            name: The resource name referring to the cached content.
        """
        client = get_default_cache_client()

        # Accept both bare ids and full "cachedContents/..." resource names.
        if "cachedContents/" not in name:
            name = "cachedContents/" + name

        request = protos.GetCachedContentRequest(name=name)
        response = client.get_cached_content(request)
        self._proto = response

    @property
    def name(self) -> str:
        # Full resource name, e.g. "cachedContents/...".
        return self._proto.name

    @property
    def model(self) -> str:
        # Name of the model this cache was created for.
        return self._proto.model

    @property
    def display_name(self) -> str:
        return self._proto.display_name

    @property
    def usage_metadata(self) -> protos.CachedContent.UsageMetadata:
        # Token accounting for the cached content.
        return self._proto.usage_metadata

    @property
    def create_time(self) -> datetime.datetime:
        return self._proto.create_time

    @property
    def update_time(self) -> datetime.datetime:
        return self._proto.update_time

    @property
    def expire_time(self) -> datetime.datetime:
        return self._proto.expire_time

    def __str__(self):
        # `{'{'}` / `{'}'}` render literal braces inside the f-string.
        return textwrap.dedent(
            f"""\
            CachedContent(
                name='{self.name}',
                model='{self.model}',
                display_name='{self.display_name}',
                usage_metadata={'{'}
                    'total_token_count': {self.usage_metadata.total_token_count},
                {'}'},
                create_time={self.create_time},
                update_time={self.update_time},
                expire_time={self.expire_time}
            )"""
        )

    __repr__ = __str__

    @classmethod
    def _from_obj(cls, obj: CachedContent | protos.CachedContent | dict) -> CachedContent:
        """Creates an instance of CachedContent from an object, without calling `get`."""
        # __new__ bypasses __init__ so no network call is made here.
        self = cls.__new__(cls)
        self._proto = protos.CachedContent()
        self._update(obj)
        return self

    def _update(self, updates):
        """Updates this instance inplace, does not call the API's `update` method."""
        if isinstance(updates, CachedContent):
            updates = updates._proto

        if not isinstance(updates, dict):
            updates = type(updates).to_dict(updates, including_default_value_fields=False)

        for key, value in updates.items():
            setattr(self._proto, key, value)

    @staticmethod
    def _prepare_create_request(
        model: str,
        *,
        display_name: str | None = None,
        system_instruction: Optional[content_types.ContentType] = None,
        contents: Optional[content_types.ContentsType] = None,
        tools: Optional[content_types.FunctionLibraryType] = None,
        tool_config: Optional[content_types.ToolConfigType] = None,
        ttl: Optional[caching_types.TTLTypes] = None,
        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
    ) -> protos.CreateCachedContentRequest:
        """Prepares a CreateCachedContentRequest.

        Validates arguments (exclusive ttl/expire_time, display-name length),
        normalizes user-friendly inputs into protos, and wraps them in the
        request object. No API call is made here.
        """
        if ttl and expire_time:
            raise ValueError(
                "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
            )

        # Accept bare model ids as well as "models/..." resource names.
        if "/" not in model:
            model = "models/" + model

        if display_name and len(display_name) > 128:
            raise ValueError("`display_name` must be no more than 128 unicode characters.")

        if system_instruction:
            system_instruction = content_types.to_content(system_instruction)

        tools_lib = content_types.to_function_library(tools)
        if tools_lib:
            tools_lib = tools_lib.to_proto()

        if tool_config:
            tool_config = content_types.to_tool_config(tool_config)

        if contents:
            contents = content_types.to_contents(contents)
            # The final Content must carry a role; default it to "user".
            if not contents[-1].role:
                contents[-1].role = _USER_ROLE

        ttl = caching_types.to_optional_ttl(ttl)
        expire_time = caching_types.to_optional_expire_time(expire_time)

        cached_content = protos.CachedContent(
            model=model,
            display_name=display_name,
            system_instruction=system_instruction,
            contents=contents,
            tools=tools_lib,
            tool_config=tool_config,
            ttl=ttl,
            expire_time=expire_time,
        )

        return protos.CreateCachedContentRequest(cached_content=cached_content)

    @classmethod
    def create(
        cls,
        model: str,
        *,
        display_name: str | None = None,
        system_instruction: Optional[content_types.ContentType] = None,
        contents: Optional[content_types.ContentsType] = None,
        tools: Optional[content_types.FunctionLibraryType] = None,
        tool_config: Optional[content_types.ToolConfigType] = None,
        ttl: Optional[caching_types.TTLTypes] = None,
        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
    ) -> CachedContent:
        """Creates `CachedContent` resource.

        Args:
            model: The name of the `model` to use for cached content creation.
                Any `CachedContent` resource can be only used with the
                `model` it was created for.
            display_name: The user-generated meaningful display name
                of the cached content. `display_name` must be no
                more than 128 unicode characters.
            system_instruction: Developer set system instruction.
            contents: Contents to cache.
            tools: A list of `Tools` the model may use to generate response.
            tool_config: Config to apply to all tools.
            ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
                `ttl` and `expire_time` are exclusive arguments.
            expire_time: Expiration time for cached resource.
                `ttl` and `expire_time` are exclusive arguments.

        Returns:
            `CachedContent` resource with specified name.
        """
        client = get_default_cache_client()

        request = cls._prepare_create_request(
            model=model,
            display_name=display_name,
            system_instruction=system_instruction,
            contents=contents,
            tools=tools,
            tool_config=tool_config,
            ttl=ttl,
            expire_time=expire_time,
        )

        response = client.create_cached_content(request)
        result = CachedContent._from_obj(response)
        return result

    @classmethod
    def get(cls, name: str) -> CachedContent:
        """Fetches required `CachedContent` resource.

        Args:
            name: The resource name referring to the cached content.

        Returns:
            `CachedContent` resource with specified `name`.
        """
        client = get_default_cache_client()

        if "cachedContents/" not in name:
            name = "cachedContents/" + name

        request = protos.GetCachedContentRequest(name=name)
        response = client.get_cached_content(request)
        result = CachedContent._from_obj(response)
        return result

    @classmethod
    def list(cls, page_size: Optional[int] = 1) -> Iterable[CachedContent]:
        """Lists `CachedContent` objects associated with the project.

        Args:
            page_size: The maximum number of permissions to return (per page).
                The service may return fewer `CachedContent` objects.

        Returns:
            A paginated list of `CachedContent` objects.
        """
        client = get_default_cache_client()

        request = protos.ListCachedContentsRequest(page_size=page_size)
        # The pager handles pagination transparently; yield wrapped results.
        for cached_content in client.list_cached_contents(request):
            cached_content = CachedContent._from_obj(cached_content)
            yield cached_content

    def delete(self) -> None:
        """Deletes `CachedContent` resource."""
        client = get_default_cache_client()

        request = protos.DeleteCachedContentRequest(name=self.name)
        client.delete_cached_content(request)
        return

    def update(
        self,
        *,
        ttl: Optional[caching_types.TTLTypes] = None,
        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
    ) -> None:
        """Updates requested `CachedContent` resource.

        Args:
            ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
                `ttl` and `expire_time` are exclusive arguments.
            expire_time: Expiration time for cached resource.
                `ttl` and `expire_time` are exclusive arguments.

        Raises:
            ValueError: If both (or neither) of `ttl` and `expire_time` are given.
        """
        client = get_default_cache_client()

        if ttl and expire_time:
            raise ValueError(
                "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
            )

        ttl = caching_types.to_optional_ttl(ttl)
        expire_time = caching_types.to_optional_expire_time(expire_time)

        updates = protos.CachedContent(
            name=self.name,
            ttl=ttl,
            expire_time=expire_time,
        )

        # Only the field actually supplied goes into the update mask.
        field_mask = field_mask_pb2.FieldMask()

        if ttl:
            field_mask.paths.append("ttl")
        elif expire_time:
            field_mask.paths.append("expire_time")
        else:
            raise ValueError(
                f"Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`."
            )

        request = protos.UpdateCachedContentRequest(cached_content=updates, update_mask=field_mask)
        updated_cc = client.update_cached_content(request)
        self._update(updated_cc)

        return
.venv/lib/python3.11/site-packages/google/generativeai/client.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import contextlib
5
+ import inspect
6
+ import dataclasses
7
+ import pathlib
8
+ import threading
9
+ from typing import Any, cast
10
+ from collections.abc import Sequence
11
+ import httplib2
12
+ from io import IOBase
13
+
14
+ import google.ai.generativelanguage as glm
15
+ import google.generativeai.protos as protos
16
+
17
+ from google.auth import credentials as ga_credentials
18
+ from google.auth import exceptions as ga_exceptions
19
+ from google import auth
20
+ from google.api_core import client_options as client_options_lib
21
+ from google.api_core import gapic_v1
22
+ from google.api_core import operations_v1
23
+
24
+ import googleapiclient.http
25
+ import googleapiclient.discovery
26
+
27
+ try:
28
+ from google.generativeai import version
29
+
30
+ __version__ = version.__version__
31
+ except ImportError:
32
+ __version__ = "0.0.0"
33
+
34
+ USER_AGENT = "genai-py"
35
+
36
+ #### Caution! ####
37
+ # - It would make sense for the discovery URL to respect the client_options.endpoint setting.
38
+ # - That would make testing Files on the staging server possible.
39
+ # - We tried fixing this once, but broke colab in the process because their endpoint didn't forward the discovery
40
+ # requests. https://github.com/google-gemini/generative-ai-python/pull/333
41
+ # - Kaggle would have a similar problem (b/362278209).
42
+ # - I think their proxy would forward the discovery traffic.
43
+ # - But they don't need to intercept the files-service at all, and uploads of large files could overload them.
44
+ # - Do the scotty uploads go to the same domain?
45
+ # - If you do route the discovery call to kaggle, be sure to attach the default_metadata (they need it).
46
+ # - One solution to all this would be if configure could take overrides per service.
47
+ # - set client_options.endpoint, but use a different endpoint for file service? It's not clear how best to do that
48
+ # through the file service.
49
+ ##################
50
+ GENAI_API_DISCOVERY_URL = "https://generativelanguage.googleapis.com/$discovery/rest"
51
+
52
+
53
@contextlib.contextmanager
def patch_colab_gce_credentials():
    """Temporarily disable GCE-metadata credential lookup when running on Colab.

    While the `with` block is active (and only if the `COLAB_RELEASE_TAG`
    environment variable is present), `google.auth`'s GCE credential probe is
    monkey-patched to report no credentials; the original is always restored.
    """
    original_get_gce = auth._default._get_gce_credentials
    if "COLAB_RELEASE_TAG" in os.environ:
        auth._default._get_gce_credentials = lambda *args, **kwargs: (None, None)

    try:
        yield
    finally:
        # Restore unconditionally, even if the body raised.
        auth._default._get_gce_credentials = original_get_gce
63
+
64
+
65
class FileServiceClient(glm.FileServiceClient):
    """File-service client with media-upload support via the REST discovery API.

    The GAPIC surface has no upload method, so `create_file` lazily builds a
    `googleapiclient` discovery client (cached per thread, since those objects
    are not thread-safe) and uses its media-upload machinery.
    """

    def __init__(self, *args, **kwargs):
        # Kept for backward compatibility with external code that may read it;
        # the working cache lives on `self._local` (see bug-fix note below).
        self._discovery_api = None
        # Per-thread storage for the lazily-built discovery client.
        self._local = threading.local()
        super().__init__(*args, **kwargs)

    def _setup_discovery_api(self, metadata: dict | Sequence[tuple[str, str]] = ()):
        """Fetch the discovery document and build this thread's discovery client.

        Args:
            metadata: (key, value) pairs sent as HTTP headers on the discovery request.

        Raises:
            ValueError: If the client was configured without an API key.
        """
        api_key = self._client_options.api_key
        if api_key is None:
            raise ValueError(
                "Invalid operation: Uploading to the File API requires an API key. Please provide a valid API key."
            )

        request = googleapiclient.http.HttpRequest(
            http=httplib2.Http(),
            postproc=lambda resp, content: (resp, content),
            uri=f"{GENAI_API_DISCOVERY_URL}?version=v1beta&key={api_key}",
            headers=dict(metadata),
        )
        response, content = request.execute()
        request.http.close()

        discovery_doc = content.decode("utf-8")
        self._local.discovery_api = googleapiclient.discovery.build_from_document(
            discovery_doc, developerKey=api_key
        )

    def create_file(
        self,
        path: str | pathlib.Path | os.PathLike | IOBase,
        *,
        mime_type: str | None = None,
        name: str | None = None,
        display_name: str | None = None,
        resumable: bool = True,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> protos.File:
        """Upload a file and return its `protos.File` record.

        Args:
            path: A filesystem path, or an open binary file-like object.
            mime_type: MIME type of the payload; inferred by the upload
                handler when None.
            name: Optional resource name for the file.
            display_name: Optional human-readable name.
            resumable: Whether to use a resumable upload.
            metadata: Extra (key, value) HTTP headers for the upload request.
        """
        # BUG FIX: the original gated setup on `self._discovery_api`, which is
        # set to None in __init__ and never assigned afterwards --
        # `_setup_discovery_api` stores the built client on
        # `self._local.discovery_api`. As a result the discovery document was
        # re-fetched on every upload. Check the thread-local cache (with
        # getattr, since a fresh thread has no attribute at all).
        if getattr(self._local, "discovery_api", None) is None:
            self._setup_discovery_api(metadata)

        file = {}
        if name is not None:
            file["name"] = name
        if display_name is not None:
            file["displayName"] = display_name

        if isinstance(path, IOBase):
            media = googleapiclient.http.MediaIoBaseUpload(
                fd=path, mimetype=mime_type, resumable=resumable
            )
        else:
            media = googleapiclient.http.MediaFileUpload(
                filename=path, mimetype=mime_type, resumable=resumable
            )

        request = self._local.discovery_api.media().upload(body={"file": file}, media_body=media)
        for key, value in metadata:
            request.headers[key] = value
        result = request.execute()

        # Re-fetch through the normal API so the caller gets a proper proto.
        return self.get_file({"name": result["file"]["name"]})
126
+
127
+
128
class FileServiceAsyncClient(glm.FileServiceAsyncClient):
    """Async file-service client; media uploads are synchronous-only."""

    async def create_file(self, *args, **kwargs):
        # Uploads are only implemented on the synchronous FileServiceClient.
        raise NotImplementedError(
            "The `create_file` method is currently not supported for the asynchronous client."
        )
133
+
134
+
135
+ @dataclasses.dataclass
136
+ class _ClientManager:
137
+ client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
138
+ default_metadata: Sequence[tuple[str, str]] = ()
139
+ clients: dict[str, Any] = dataclasses.field(default_factory=dict)
140
+
141
    def configure(
        self,
        *,
        api_key: str | None = None,
        credentials: ga_credentials.Credentials | dict | None = None,
        # The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'.
        # See _transport_registry in the google.ai.generativelanguage package.
        # Since the transport classes align with the client classes it wouldn't make
        # sense to accept a `Transport` object here even though the client classes can.
        # We could accept a dict since all the `Transport` classes take the same args,
        # but that seems rare. Users that need it can just switch to the low level API.
        transport: str | None = None,
        client_options: client_options_lib.ClientOptions | dict[str, Any] | None = None,
        client_info: gapic_v1.client_info.ClientInfo | None = None,
        default_metadata: Sequence[tuple[str, str]] = (),
    ) -> None:
        """Initializes default client configurations using specified parameters or environment variables.

        If no API key has been provided (either directly, or on `client_options`) and the
        `GEMINI_API_KEY` environment variable is set, it will be used as the API key. If not,
        if the `GOOGLE_API_KEY` environment variable is set, it will be used as the API key.

        Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in
        `google.ai.generativelanguage` for details on the other arguments.

        Args:
            transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
            api_key: The API-Key to use when creating the default clients (each service uses
                a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
                If omitted, and the `GEMINI_API_KEY` or the `GOOGLE_API_KEY` environment variable
                are set, they will be used in this order of priority.
            default_metadata: Default (key, value) metadata pairs to send with every request.
                when using `transport="rest"` these are sent as HTTP headers.
        """
        # Normalize `client_options` to a ClientOptions instance.
        if isinstance(client_options, dict):
            client_options = client_options_lib.from_dict(client_options)
        if client_options is None:
            client_options = client_options_lib.ClientOptions()
        client_options = cast(client_options_lib.ClientOptions, client_options)
        had_api_key_value = getattr(client_options, "api_key", None)

        if had_api_key_value:
            # A key on client_options and an explicit `api_key` are mutually exclusive.
            if api_key is not None:
                raise ValueError(
                    "Invalid configuration: Please set either `api_key` or `client_options['api_key']`, but not both."
                )
        else:
            if api_key is None:
                # If no key is provided explicitly, attempt to load one from the
                # environment.
                api_key = os.getenv("GEMINI_API_KEY")

            if api_key is None:
                # If the GEMINI_API_KEY doesn't exist, attempt to load the
                # GOOGLE_API_KEY from the environment.
                api_key = os.getenv("GOOGLE_API_KEY")

            client_options.api_key = api_key

        user_agent = f"{USER_AGENT}/{__version__}"
        if client_info:
            # Be respectful of any existing agent setting.
            if client_info.user_agent:
                client_info.user_agent += f" {user_agent}"
            else:
                client_info.user_agent = user_agent
        else:
            client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)

        client_config = {
            "credentials": credentials,
            "transport": transport,
            "client_options": client_options,
            "client_info": client_info,
        }

        # Drop unset options so each client constructor falls back to its defaults.
        client_config = {key: value for key, value in client_config.items() if value is not None}

        self.client_config = client_config
        self.default_metadata = default_metadata

        # Invalidate previously-built clients; they are rebuilt lazily with the
        # new configuration.
        self.clients = {}
223
+
224
def make_client(self, name):
    """Builds a new low-level service client for `name` (e.g. "model", "generative_async").

    The client class is resolved from `name`: the file service is special-cased
    (it is imported directly, not resolved off `glm`), any other `*_async` name maps
    to `glm.<Name>ServiceAsyncClient`, and everything else to `glm.<Name>ServiceClient`.
    The client is constructed from `self.client_config`; if nothing has been
    configured yet, module-level `configure()` is called first to capture defaults.
    When `self.default_metadata` is non-empty, every public metadata-accepting method
    of the returned client is wrapped so those (key, value) pairs accompany each call.
    """
    if name == "file":
        cls = FileServiceClient
    elif name == "file_async":
        cls = FileServiceAsyncClient
    elif name.endswith("_async"):
        # Strip the suffix: "model_async" -> "model" -> glm.ModelServiceAsyncClient.
        name = name.split("_")[0]
        cls = getattr(glm, name.title() + "ServiceAsyncClient")
    else:
        cls = getattr(glm, name.title() + "ServiceClient")

    # Attempt to configure using defaults.
    if not self.client_config:
        configure()

    try:
        # Colab's GCE credentials are patched out so the API key path is used there.
        with patch_colab_gce_credentials():
            client = cls(**self.client_config)
    except ga_exceptions.DefaultCredentialsError as e:
        # Rewrite the error with actionable guidance before re-raising.
        e.args = (
            "\n No API_KEY or ADC found. Please either:\n"
            " - Set the `GOOGLE_API_KEY` environment variable.\n"
            " - Manually pass the key with `genai.configure(api_key=my_api_key)`.\n"
            " - Or set up Application Default Credentials, see https://ai.google.dev/gemini-api/docs/oauth for more information.",
        )
        raise e

    if not self.default_metadata:
        return client

    def keep(name, f):
        # Only wrap public, callable attributes that accept a `metadata` kwarg.
        if name.startswith("_"):
            return False

        if not callable(f):
            return False

        if "metadata" not in inspect.signature(f).parameters.keys():
            return False

        return True

    def add_default_metadata_wrapper(f):
        def call(*args, metadata=(), **kwargs):
            # Caller-supplied metadata comes first; defaults are appended after it.
            metadata = list(metadata) + list(self.default_metadata)
            return f(*args, **kwargs, metadata=metadata)

        return call

    # Enumerate members on the class, but patch the bound methods on the instance.
    for name, value in inspect.getmembers(cls):
        if not keep(name, value):
            continue
        f = getattr(client, name)
        f = add_default_metadata_wrapper(f)
        setattr(client, name, f)

    return client
281
+
282
def get_default_client(self, name):
    """Returns the cached client for `name`, constructing it on first use."""
    key = name.lower()
    # The operations client is derived from the model client's transport,
    # so it has a dedicated accessor.
    if key == "operations":
        return self.get_default_operations_client()

    cached = self.clients.get(key)
    if cached is not None:
        return cached

    new_client = self.make_client(key)
    self.clients[key] = new_client
    return new_client
292
+
293
def get_default_operations_client(self) -> operations_v1.OperationsClient:
    """Returns the shared operations client, creating and caching it on first use."""
    cached = self.clients.get("operations", None)
    if cached is not None:
        return cached

    # There is no standalone operations service client here; it is exposed
    # by the model-service client's transport.
    ops_client = self.get_default_client("Model")._transport.operations_client
    self.clients["operations"] = ops_client
    return ops_client
300
+
301
+
302
def configure(
    *,
    api_key: str | None = None,
    credentials: ga_credentials.Credentials | dict | None = None,
    # The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'.
    # Since the transport classes align with the client classes it wouldn't make
    # sense to accept a `Transport` object here even though the client classes can.
    # We could accept a dict since all the `Transport` classes take the same args,
    # but that seems rare. Users that need it can just switch to the low level API.
    transport: str | None = None,
    client_options: client_options_lib.ClientOptions | dict | None = None,
    client_info: gapic_v1.client_info.ClientInfo | None = None,
    default_metadata: Sequence[tuple[str, str]] = (),
):
    """Captures default client configuration.

    If no API key has been provided (either directly, or on `client_options`) and the
    `GEMINI_API_KEY` environment variable is set, it will be used as the API key; failing
    that, the `GOOGLE_API_KEY` environment variable is used (this precedence is
    implemented by `_ClientManager.configure`, to which this function delegates).

    Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in
    `google.ai.generativelanguage` for details on the other arguments.

    Args:
        transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
        api_key: The API-Key to use when creating the default clients (each service uses
            a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
            If omitted, and the `GEMINI_API_KEY` or `GOOGLE_API_KEY` environment variable
            is set, it will be used, in that order of priority.
        default_metadata: Default (key, value) metadata pairs to send with every request.
            When using `transport="rest"` these are sent as HTTP headers.
    """
    # All validation and environment-variable fallback lives on the manager.
    return _client_manager.configure(
        api_key=api_key,
        credentials=credentials,
        transport=transport,
        client_options=client_options,
        client_info=client_info,
        default_metadata=default_metadata,
    )
341
+
342
+
343
# Module-level singleton holding the shared client configuration and the
# per-service client cache. Configured once at import time with defaults;
# `genai.configure(...)` re-configures it.
_client_manager = _ClientManager()
_client_manager.configure()
345
+
346
+
347
# Convenience accessors for the lazily-built, cached default service clients.
# Fixes relative to the previous revision (annotations only; behavior unchanged):
#  - file getters referenced non-existent `glm.FilesService*Client`; they return the
#    directly-imported `FileServiceClient`/`FileServiceAsyncClient` (see `make_client`).
#  - `get_default_model_client` builds the *sync* "model" client, so it is annotated
#    with `glm.ModelServiceClient`, not the async variant.
#  - retriever getters use the `*Service*` class names consistent with the rest
#    of this module.


def get_default_cache_client() -> glm.CacheServiceClient:
    return _client_manager.get_default_client("cache")


def get_default_file_client() -> FileServiceClient:
    return _client_manager.get_default_client("file")


def get_default_file_async_client() -> FileServiceAsyncClient:
    return _client_manager.get_default_client("file_async")


def get_default_generative_client() -> glm.GenerativeServiceClient:
    return _client_manager.get_default_client("generative")


def get_default_generative_async_client() -> glm.GenerativeServiceAsyncClient:
    return _client_manager.get_default_client("generative_async")


def get_default_operations_client() -> operations_v1.OperationsClient:
    return _client_manager.get_default_client("operations")


def get_default_model_client() -> glm.ModelServiceClient:
    return _client_manager.get_default_client("model")


def get_default_retriever_client() -> glm.RetrieverServiceClient:
    return _client_manager.get_default_client("retriever")


def get_default_retriever_async_client() -> glm.RetrieverServiceAsyncClient:
    return _client_manager.get_default_client("retriever_async")


def get_default_permission_client() -> glm.PermissionServiceClient:
    return _client_manager.get_default_client("permission")


def get_default_permission_async_client() -> glm.PermissionServiceAsyncClient:
    return _client_manager.get_default_client("permission_async")
.venv/lib/python3.11/site-packages/google/generativeai/discuss.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ import dataclasses
18
+ import sys
19
+ import textwrap
20
+
21
+ from typing import Any, Iterable, List, Optional, Union
22
+
23
+ import google.ai.generativelanguage as glm
24
+
25
+ from google.generativeai.client import get_default_discuss_client
26
+ from google.generativeai.client import get_default_discuss_async_client
27
+ from google.generativeai import string_utils
28
+ from google.generativeai.types import discuss_types
29
+ from google.generativeai.types import model_types
30
+ from google.generativeai.types import safety_types
31
+
32
+
33
def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
    """Coerces `content` (a `glm.Message`, `str`, or message dict) into a `glm.Message`."""
    if isinstance(content, glm.Message):
        # Already the target type; pass through untouched.
        return content
    if isinstance(content, str):
        return glm.Message(content=content)
    # Anything else (e.g. a dict of fields) is handed to the proto constructor.
    return glm.Message(content)
41
+
42
+
43
def _make_messages(
    messages: discuss_types.MessagesOptions,
) -> List[glm.Message]:
    """
    Creates a list of `glm.Message` objects from the provided messages.

    Accepts a single message (string, dict, or `glm.Message`) or an iterable of
    them. Messages at even and odd positions must each have at most one distinct
    author; if any message is missing an author, canonical alternating authors
    ("0"/"1" by default) are assigned to every message.

    Args:
        messages: The messages to convert.

    Returns:
        A list of `glm.Message` objects with alternating authors.

    Raises:
        discuss_types.AuthorError: If either position-parity group contains more
            than one distinct author.
    """
    if isinstance(messages, (str, dict, glm.Message)):
        converted = [_make_message(messages)]
    else:
        converted = [_make_message(item) for item in messages]

    def resolve_author(group, fallback):
        # Collect the distinct, non-empty authors at one parity.
        names = {m.author for m in group if m.author}
        if not names:
            return fallback
        if len(names) == 1:
            return names.pop()
        raise discuss_types.AuthorError("Authors are not strictly alternating")

    even_author = resolve_author(converted[::2], "0")
    odd_author = resolve_author(converted[1::2], "1")

    # If every message already names its author, keep them as-is.
    if all(m.author for m in converted):
        return converted

    alternating = [even_author, odd_author]
    for index, msg in enumerate(converted):
        msg.author = alternating[index % 2]

    return converted
89
+
90
+
91
def _make_example(item: discuss_types.ExampleOptions) -> glm.Example:
    """Coerces `item` (a `glm.Example`, dict, or input/output pair) into a `glm.Example`."""
    if isinstance(item, glm.Example):
        return item

    if isinstance(item, dict):
        # Copy before mutating so the caller's dict is left untouched.
        fields = item.copy()
        fields["input"] = _make_message(fields["input"])
        fields["output"] = _make_message(fields["output"])
        return glm.Example(fields)

    if isinstance(item, Iterable):
        # Exactly two elements expected: (input, output).
        first, second = list(item)
        return glm.Example(input=_make_message(first), output=_make_message(second))

    # try anyway: let the proto constructor attempt the conversion.
    return glm.Example(item)
108
+
109
+
110
def _make_examples_from_flat(
    examples: List[discuss_types.MessageOptions],
) -> List[glm.Example]:
    """
    Creates a list of `glm.Example` objects from a flat list of message options.

    Consecutive messages are paired: element 0 becomes the input and element 1
    the output of the first example, and so on. The list must therefore have an
    even number of elements.

    Args:
        examples: The flat list of `discuss_types.MessageOptions`.

    Returns:
        A list of `glm.Example` objects created by pairing up the messages.

    Raises:
        ValueError: If the provided list has an odd number of elements.
    """
    if len(examples) % 2 != 0:
        raise ValueError(
            textwrap.dedent(
                f"""\
                You must pass `Primer` objects, pairs of messages, or an *even* number of messages, got:
                {len(examples)} messages"""
            )
        )

    converted = [_make_message(item) for item in examples]
    # Pair even-indexed (inputs) with odd-indexed (outputs) messages.
    return [
        glm.Example(input=question, output=answer)
        for question, answer in zip(converted[::2], converted[1::2])
    ]
150
+
151
+
152
def _make_examples(
    examples: discuss_types.ExamplesOptions,
) -> List[glm.Example]:
    """
    Creates a list of `glm.Example` objects from the provided examples.

    Accepts a single example (`glm.Example` or dict), a flat list of messages
    (paired up in order), or a list of per-example items.

    Args:
        examples: The examples to convert.

    Returns:
        A list of `glm.Example` objects created from the provided examples.

    Raises:
        TypeError: If a leading dict has neither message shape ("content") nor
            example shape ("input" and "output").
    """
    if isinstance(examples, glm.Example):
        return [examples]

    if isinstance(examples, dict):
        return [_make_example(examples)]

    examples = list(examples)
    if not examples:
        return examples

    # Inspect the first element to decide how to treat the whole list.
    head = examples[0]
    if isinstance(head, dict):
        if "content" in head:
            # These are `Messages`: pair them up into examples.
            return _make_examples_from_flat(examples)
        if not ("input" in head and "output" in head):
            raise TypeError(
                "To create an `Example` from a dict you must supply both `input` and an `output` keys"
            )
    elif isinstance(head, discuss_types.MESSAGE_OPTIONS):
        # A flat list of message options: pair them up.
        return _make_examples_from_flat(examples)

    # Per-item conversion: each element already describes one example.
    return [_make_example(item) for item in examples]
198
+
199
+
200
def _make_message_prompt_dict(
    prompt: discuss_types.MessagePromptOptions = None,
    *,
    context: str | None = None,
    examples: discuss_types.ExamplesOptions | None = None,
    messages: discuss_types.MessagesOptions | None = None,
) -> glm.MessagePrompt | dict:
    """
    Creates the contents of a `glm.MessagePrompt` from the provided components.

    Either pass a `prompt` or its components `context`, `examples`, `messages`,
    but not both.

    Args:
        prompt: The complete prompt: a `glm.MessagePrompt`, a dict of its fields,
            or a bare messages iterable.
        context: The context for the prompt.
        examples: The examples for the prompt.
        messages: The messages for the prompt.

    Returns:
        A dict of normalized `glm.MessagePrompt` fields with unset entries removed,
        or the `prompt` itself when it is already a `glm.MessagePrompt`.

    Raises:
        ValueError: If both `prompt` and any of its components are set.
        KeyError: If the prompt dict contains keys other than the
            `glm.MessagePrompt` fields.
    """
    if prompt is None:
        # Assemble the prompt dict from the individual keyword components.
        prompt = dict(
            context=context,
            examples=examples,
            messages=messages,
        )
    else:
        # `prompt` and the flat components are mutually exclusive.
        flat_prompt = (context is not None) or (examples is not None) or (messages is not None)
        if flat_prompt:
            raise ValueError(
                "You can't set `prompt`, and its fields `(context, examples, messages)`"
                " at the same time"
            )
        if isinstance(prompt, glm.MessagePrompt):
            # Already fully formed; nothing to normalize.
            return prompt
        elif isinstance(prompt, dict):  # Always check dict before Iterable.
            pass
        else:
            # Any other iterable is interpreted as a bare list of messages.
            prompt = {"messages": prompt}

    keys = set(prompt.keys())
    if not keys.issubset(discuss_types.MESSAGE_PROMPT_KEYS):
        raise KeyError(
            f"Found extra entries in the prompt dictionary: {keys - discuss_types.MESSAGE_PROMPT_KEYS}"
        )

    # Normalize the nested fields to proto-compatible values.
    examples = prompt.get("examples", None)
    if examples is not None:
        prompt["examples"] = _make_examples(examples)
    messages = prompt.get("messages", None)
    if messages is not None:
        prompt["messages"] = _make_messages(messages)

    # Drop unset fields so proto defaults apply when the dict is converted.
    prompt = {k: v for k, v in prompt.items() if v is not None}
    return prompt
259
+
260
+
261
def _make_message_prompt(
    prompt: discuss_types.MessagePromptOptions = None,
    *,
    context: str | None = None,
    examples: discuss_types.ExamplesOptions | None = None,
    messages: discuss_types.MessagesOptions | None = None,
) -> glm.MessagePrompt:
    """Creates a `glm.MessagePrompt` from a whole prompt or its components."""
    # Normalization (validation, message/example coercion) happens in the
    # dict-building helper; here we only wrap the result in the proto type.
    normalized = _make_message_prompt_dict(
        prompt=prompt, context=context, examples=examples, messages=messages
    )
    return glm.MessagePrompt(normalized)
273
+
274
+
275
def _make_generate_message_request(
    *,
    model: model_types.AnyModelNameOptions | None,
    context: str | None = None,
    examples: discuss_types.ExamplesOptions | None = None,
    messages: discuss_types.MessagesOptions | None = None,
    temperature: float | None = None,
    candidate_count: int | None = None,
    top_p: float | None = None,
    top_k: float | None = None,
    prompt: discuss_types.MessagePromptOptions | None = None,
) -> glm.GenerateMessageRequest:
    """Assembles a `glm.GenerateMessageRequest` from the given components."""
    # Resolve shorthand model specs (e.g. "chat-bison-001") to full names.
    resolved_model = model_types.make_model_name(model)

    message_prompt = _make_message_prompt(
        prompt=prompt, context=context, examples=examples, messages=messages
    )

    return glm.GenerateMessageRequest(
        model=resolved_model,
        prompt=message_prompt,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        candidate_count=candidate_count,
    )
302
+
303
+
304
# Model used by the discuss API when the caller does not specify one.
DEFAULT_DISCUSS_MODEL = "models/chat-bison-001"
305
+
306
+
307
def chat(
    *,
    # Use the module constant instead of duplicating the literal; same value,
    # single source of truth.
    model: model_types.AnyModelNameOptions | None = DEFAULT_DISCUSS_MODEL,
    context: str | None = None,
    examples: discuss_types.ExamplesOptions | None = None,
    messages: discuss_types.MessagesOptions | None = None,
    temperature: float | None = None,
    candidate_count: int | None = None,
    top_p: float | None = None,
    top_k: float | None = None,
    prompt: discuss_types.MessagePromptOptions | None = None,
    client: glm.DiscussServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> discuss_types.ChatResponse:
    """Calls the API and returns a `types.ChatResponse` containing the response.

    Args:
        model: Which model to call, as a string or a `types.Model`.
        context: Text that should be provided to the model first, to ground the response.

            If not empty, this `context` will be given to the model first before the
            `examples` and `messages`.

            This field can be a description of your prompt to the model to help provide
            context and guide the responses.

            Examples:

            * "Translate the phrase from English to French."
            * "Given a statement, classify the sentiment as happy, sad or neutral."

            Anything included in this field will take precedence over history in `messages`
            if the total input size exceeds the model's `Model.input_token_limit`.
        examples: Examples of what the model should generate.

            This includes both the user input and the response that the model should
            emulate.

            These `examples` are treated identically to conversation messages except
            that they take precedence over the history in `messages`:
            If the total input size exceeds the model's `input_token_limit` the input
            will be truncated. Items will be dropped from `messages` before `examples`
        messages: A snapshot of the conversation history sorted chronologically.

            Turns alternate between two authors.

            If the total input size exceeds the model's `input_token_limit` the input
            will be truncated: The oldest items will be dropped from `messages`.
        temperature: Controls the randomness of the output. Must be positive.

            Typical values are in the range: `[0.0,1.0]`. Higher values produce a
            more random and varied response. A temperature of zero will be deterministic.
        candidate_count: The **maximum** number of generated response messages to return.

            This value must be between `[1, 8]`, inclusive. If unset, this
            will default to `1`.

            Note: Only unique candidates are returned. Higher temperatures are more
            likely to produce unique candidates. Setting `temperature=0.0` will always
            return 1 candidate regardless of the `candidate_count`.
        top_k: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and
            top-k sampling.

            `top_k` sets the maximum number of tokens to sample from on each step.
        top_p: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and
            top-k sampling.

            `top_p` configures the nucleus sampling. It sets the maximum cumulative
            probability of tokens to sample from.

            For example, if the sorted probabilities are
            `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample
            as `[0.625, 0.25, 0.125, 0, 0, 0]`.

            Typical values are in the `[0.9, 1.0]` range.
        prompt: You may pass a `types.MessagePromptOptions` **instead** of a
            setting `context`/`examples`/`messages`, but not both.
        client: If you're not relying on the default client, you pass a
            `glm.DiscussServiceClient` instead.
        request_options: Options for the request.

    Returns:
        A `types.ChatResponse` containing the model's reply.
    """
    request = _make_generate_message_request(
        model=model,
        context=context,
        examples=examples,
        messages=messages,
        temperature=temperature,
        candidate_count=candidate_count,
        top_p=top_p,
        top_k=top_k,
        prompt=prompt,
    )

    return _generate_response(client=client, request=request, request_options=request_options)
404
+
405
+
406
@string_utils.set_doc(chat.__doc__)
async def chat_async(
    *,
    # Use the module constant instead of duplicating the literal (matches `chat`).
    model: model_types.AnyModelNameOptions | None = DEFAULT_DISCUSS_MODEL,
    context: str | None = None,
    examples: discuss_types.ExamplesOptions | None = None,
    messages: discuss_types.MessagesOptions | None = None,
    temperature: float | None = None,
    candidate_count: int | None = None,
    top_p: float | None = None,
    top_k: float | None = None,
    prompt: discuss_types.MessagePromptOptions | None = None,
    client: glm.DiscussServiceAsyncClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> discuss_types.ChatResponse:
    # Async twin of `chat` (docstring copied via `set_doc` above).
    request = _make_generate_message_request(
        model=model,
        context=context,
        examples=examples,
        messages=messages,
        temperature=temperature,
        candidate_count=candidate_count,
        top_p=top_p,
        top_k=top_k,
        prompt=prompt,
    )

    return await _generate_response_async(
        client=client, request=request, request_options=request_options
    )
436
+
437
+
438
# `dataclass(kw_only=...)` only exists on Python 3.10+; on older versions the
# option is simply omitted.
DATACLASS_KWARGS = {"kw_only": True} if sys.version_info[:2] >= (3, 10) else {}
442
+
443
+
444
@string_utils.prettyprint
@string_utils.set_doc(discuss_types.ChatResponse.__doc__)
@dataclasses.dataclass(**DATACLASS_KWARGS, init=False)
class ChatResponse(discuss_types.ChatResponse):
    # The client used for the original request; `reply`/`reply_async` reuse it.
    # NOTE(review): `default=lambda: None` makes the *lambda itself* the default
    # value rather than `None`; `default_factory` was presumably intended. Harmless
    # in practice because the custom `__init__` below always receives `_client`,
    # but worth confirming.
    _client: glm.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False)

    def __init__(self, **kwargs):
        # Accept arbitrary response/request fields and set them as attributes,
        # bypassing the generated dataclass __init__ (init=False above).
        for key, value in kwargs.items():
            setattr(self, key, value)

    @property
    @string_utils.set_doc(discuss_types.ChatResponse.last.__doc__)
    def last(self) -> str | None:
        # The final entry of `messages` holds the model's latest reply; it is
        # `None`-like when all candidates were filtered out.
        if self.messages[-1]:
            return self.messages[-1]["content"]
        else:
            return None

    @last.setter
    def last(self, message: discuss_types.MessageOptions):
        # Store the message as a plain dict (matching how messages are kept here).
        message = _make_message(message)
        message = type(message).to_dict(message)
        self.messages[-1] = message

    @string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
    def reply(
        self,
        message: discuss_types.MessageOptions,
        request_options: dict[str, Any] | None = None,
    ) -> discuss_types.ChatResponse:
        # Guard against mixing sync/async clients.
        if isinstance(self._client, glm.DiscussServiceAsyncClient):
            raise TypeError(f"reply can't be called on an async client, use reply_async instead.")
        if self.last is None:
            raise ValueError(
                "The last response from the model did not return any candidates.\n"
                "Check the `.filters` attribute to see why the responses were filtered:\n"
                f"{self.filters}"
            )

        # Rebuild a request from this response's state, minus response-only
        # fields, with the new message appended to the history.
        request = self.to_dict()
        request.pop("candidates")
        request.pop("filters", None)
        request["messages"] = list(request["messages"])
        request["messages"].append(_make_message(message))
        request = _make_generate_message_request(**request)
        return _generate_response(
            request=request, client=self._client, request_options=request_options
        )

    @string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
    async def reply_async(
        self, message: discuss_types.MessageOptions
    ) -> discuss_types.ChatResponse:
        # Mirror of `reply` for async clients.
        if isinstance(self._client, glm.DiscussServiceClient):
            raise TypeError(
                f"reply_async can't be called on a non-async client, use reply instead."
            )
        request = self.to_dict()
        request.pop("candidates")
        request.pop("filters", None)
        request["messages"] = list(request["messages"])
        request["messages"].append(_make_message(message))
        request = _make_generate_message_request(**request)
        return await _generate_response_async(request=request, client=self._client)
508
+
509
+
510
def _build_chat_response(
    request: glm.GenerateMessageRequest,
    response: glm.GenerateMessageResponse,
    client: glm.DiscussServiceClient | glm.DiscussServiceAsyncClient,
) -> ChatResponse:
    """Merges the request and response protos into a single `ChatResponse`.

    The prompt fields are flattened into the request dict, response-side
    messages are discarded, and the top candidate (if any) is appended to the
    message history so `.last` reflects the model's reply.
    """
    # Flatten the nested prompt into top-level request fields.
    request = type(request).to_dict(request)
    prompt = request.pop("prompt")
    request["examples"] = prompt["examples"]
    request["context"] = prompt["context"]
    request["messages"] = prompt["messages"]

    # The response's `messages` echo the request history; drop them.
    response = type(response).to_dict(response)
    response.pop("messages")

    response["filters"] = safety_types.convert_filters_to_enums(response["filters"])

    # Append the top candidate (or None when everything was filtered) so that
    # `ChatResponse.last` can read it.
    if response["candidates"]:
        last = response["candidates"][0]
    else:
        last = None
    request["messages"].append(last)
    # Ensure the dataclass fields exist even when the request omitted them.
    request.setdefault("temperature", None)
    request.setdefault("candidate_count", None)

    return ChatResponse(_client=client, **response, **request)  # pytype: disable=missing-parameter
535
+
536
+
537
def _generate_response(
    request: glm.GenerateMessageRequest,
    client: glm.DiscussServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> ChatResponse:
    """Sends `request` through `client` (or the default sync client) and wraps the result."""
    opts = {} if request_options is None else request_options
    discuss_client = get_default_discuss_client() if client is None else client

    # Request options (e.g. retry/timeout) are forwarded as keyword arguments.
    raw_response = discuss_client.generate_message(request, **opts)

    return _build_chat_response(request, raw_response, discuss_client)
551
+
552
+
553
async def _generate_response_async(
    request: glm.GenerateMessageRequest,
    client: glm.DiscussServiceAsyncClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> ChatResponse:
    """Async twin of `_generate_response`, using the async discuss client."""
    opts = {} if request_options is None else request_options
    discuss_client = get_default_discuss_async_client() if client is None else client

    # Request options (e.g. retry/timeout) are forwarded as keyword arguments.
    raw_response = await discuss_client.generate_message(request, **opts)

    return _build_chat_response(request, raw_response, discuss_client)
567
+
568
+
569
def count_message_tokens(
    *,
    prompt: discuss_types.MessagePromptOptions = None,
    context: str | None = None,
    examples: discuss_types.ExamplesOptions | None = None,
    messages: discuss_types.MessagesOptions | None = None,
    model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL,
    # Fixed annotation: this function uses the *sync* discuss client
    # (`get_default_discuss_client`) and a blocking call, so the async client
    # type was incorrect here.
    client: glm.DiscussServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> discuss_types.TokenCount:
    """Counts the tokens the given prompt would consume on `model`.

    Args:
        prompt: A complete `types.MessagePromptOptions`, **instead of** the
            `context`/`examples`/`messages` components.
        context: Grounding text given to the model first.
        examples: Input/output examples for the model to emulate.
        messages: The conversation history.
        model: Which model to count tokens for.
        client: Optional `glm.DiscussServiceClient`; the default client is used
            when omitted.
        request_options: Options for the request (e.g. retry, timeout).

    Returns:
        The token count as a dict (`discuss_types.TokenCount`).
    """
    model = model_types.make_model_name(model)
    prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)

    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_discuss_client()

    result = client.count_message_tokens(model=model, prompt=prompt, **request_options)

    return type(result).to_dict(result)
.venv/lib/python3.11/site-packages/google/generativeai/embedding.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ import itertools
18
+ from typing import Any, Iterable, overload, TypeVar, Union, Mapping
19
+
20
+ import google.ai.generativelanguage as glm
21
+ from google.generativeai import protos
22
+
23
+ from google.generativeai.client import get_default_generative_client
24
+ from google.generativeai.client import get_default_generative_async_client
25
+
26
+ from google.generativeai.types import helper_types
27
+ from google.generativeai.types import model_types
28
+ from google.generativeai.types import text_types
29
+ from google.generativeai.types import content_types
30
+
31
# Default model used when the caller does not name one.
DEFAULT_EMB_MODEL = "models/embedding-001"
# The batch endpoint accepts at most this many requests per call; larger
# inputs are split into chunks of this size before being sent.
EMBEDDING_MAX_BATCH_SIZE = 100

# Alias for the proto enum of embedding task types.
EmbeddingTaskType = protos.TaskType

# Anything `to_task_type` can normalize: the enum member itself, its integer
# value, or a (case-insensitive) string name or shorthand.
EmbeddingTaskTypeOptions = Union[int, str, EmbeddingTaskType]

# Lookup table mapping every accepted spelling (enum member, int value,
# canonical lower-case string, shorthand string) onto the canonical enum
# member. NOTE(review): proto enum members compare equal to their int values,
# so the enum-member and int entries likely share a dict slot — harmless,
# since both map to the same canonical member.
_EMBEDDING_TASK_TYPE: dict[EmbeddingTaskTypeOptions, EmbeddingTaskType] = {
    EmbeddingTaskType.TASK_TYPE_UNSPECIFIED: EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
    0: EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
    "task_type_unspecified": EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
    "unspecified": EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
    EmbeddingTaskType.RETRIEVAL_QUERY: EmbeddingTaskType.RETRIEVAL_QUERY,
    1: EmbeddingTaskType.RETRIEVAL_QUERY,
    "retrieval_query": EmbeddingTaskType.RETRIEVAL_QUERY,
    "query": EmbeddingTaskType.RETRIEVAL_QUERY,
    EmbeddingTaskType.RETRIEVAL_DOCUMENT: EmbeddingTaskType.RETRIEVAL_DOCUMENT,
    2: EmbeddingTaskType.RETRIEVAL_DOCUMENT,
    "retrieval_document": EmbeddingTaskType.RETRIEVAL_DOCUMENT,
    "document": EmbeddingTaskType.RETRIEVAL_DOCUMENT,
    EmbeddingTaskType.SEMANTIC_SIMILARITY: EmbeddingTaskType.SEMANTIC_SIMILARITY,
    3: EmbeddingTaskType.SEMANTIC_SIMILARITY,
    "semantic_similarity": EmbeddingTaskType.SEMANTIC_SIMILARITY,
    "similarity": EmbeddingTaskType.SEMANTIC_SIMILARITY,
    EmbeddingTaskType.CLASSIFICATION: EmbeddingTaskType.CLASSIFICATION,
    4: EmbeddingTaskType.CLASSIFICATION,
    "classification": EmbeddingTaskType.CLASSIFICATION,
    EmbeddingTaskType.CLUSTERING: EmbeddingTaskType.CLUSTERING,
    5: EmbeddingTaskType.CLUSTERING,
    "clustering": EmbeddingTaskType.CLUSTERING,
    6: EmbeddingTaskType.QUESTION_ANSWERING,
    "question_answering": EmbeddingTaskType.QUESTION_ANSWERING,
    "qa": EmbeddingTaskType.QUESTION_ANSWERING,
    EmbeddingTaskType.QUESTION_ANSWERING: EmbeddingTaskType.QUESTION_ANSWERING,
    7: EmbeddingTaskType.FACT_VERIFICATION,
    "fact_verification": EmbeddingTaskType.FACT_VERIFICATION,
    "verification": EmbeddingTaskType.FACT_VERIFICATION,
    EmbeddingTaskType.FACT_VERIFICATION: EmbeddingTaskType.FACT_VERIFICATION,
}
70
+
71
+
72
def to_task_type(x: EmbeddingTaskTypeOptions) -> EmbeddingTaskType:
    """Normalize `x` (enum, int, or string shorthand) to an `EmbeddingTaskType`."""
    key = x.lower() if isinstance(x, str) else x
    return _EMBEDDING_TASK_TYPE[key]
76
+
77
+
78
+ try:
79
+ # python 3.12+
80
+ _batched = itertools.batched # type: ignore
81
+ except AttributeError:
82
+ T = TypeVar("T")
83
+
84
+ def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
85
+ if n < 1:
86
+ raise ValueError(
87
+ f"Invalid input: The batch size 'n' must be a positive integer. You entered: {n}. Please enter a number greater than 0."
88
+ )
89
+ batch = []
90
+ for item in iterable:
91
+ batch.append(item)
92
+ if len(batch) == n:
93
+ yield batch
94
+ batch = []
95
+
96
+ if batch:
97
+ yield batch
98
+
99
+
100
@overload
def embed_content(
    model: model_types.BaseModelNameOptions,
    content: content_types.ContentType,
    task_type: EmbeddingTaskTypeOptions | None = None,
    title: str | None = None,
    output_dimensionality: int | None = None,
    client: glm.GenerativeServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.EmbeddingDict: ...


@overload
def embed_content(
    model: model_types.BaseModelNameOptions,
    content: Iterable[content_types.ContentType],
    task_type: EmbeddingTaskTypeOptions | None = None,
    title: str | None = None,
    output_dimensionality: int | None = None,
    client: glm.GenerativeServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.BatchEmbeddingDict: ...


def embed_content(
    model: model_types.BaseModelNameOptions,
    content: content_types.ContentType | Iterable[content_types.ContentType],
    task_type: EmbeddingTaskTypeOptions | None = None,
    title: str | None = None,
    output_dimensionality: int | None = None,
    client: glm.GenerativeServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
    """Calls the API to create embeddings for content passed in.

    Args:
        model:
            Which [model](https://ai.google.dev/models/gemini#embedding) to
            call, as a string or a `types.Model`.

        content:
            Content to embed.

        task_type:
            Optional task type for which the embeddings will be used. Can only
            be set for `models/embedding-001`.

        title:
            An optional title for the text. Only applicable when task_type is
            `RETRIEVAL_DOCUMENT`.

        output_dimensionality:
            Optional reduced dimensionality for the output embeddings. If set,
            excessive values from the output embeddings will be truncated from
            the end.

        request_options:
            Options for the request.

    Return:
        Dictionary containing the embedding (list of float values) for the
        input content.

    Raises:
        ValueError: If a title is given without a `RETRIEVAL_DOCUMENT` task
            type, or if `output_dimensionality` is negative.
    """
    model = model_types.make_model_name(model)

    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_generative_client()

    # A title only makes sense for retrieval-document embeddings. Checking
    # `task_type is None` first gives a clear ValueError instead of the opaque
    # KeyError the task-type lookup would raise on None.
    if title and (
        task_type is None or to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT
    ):
        raise ValueError(
            f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}."
        )

    if output_dimensionality and output_dimensionality < 0:
        raise ValueError(
            f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}."
        )

    if task_type:
        task_type = to_task_type(task_type)

    if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)):
        # Batch case: build one request per item lazily, then send them in
        # API-sized chunks and flatten all returned vectors into one list.
        result = {"embedding": []}
        requests = (
            protos.EmbedContentRequest(
                model=model,
                content=content_types.to_content(c),
                task_type=task_type,
                title=title,
                output_dimensionality=output_dimensionality,
            )
            for c in content
        )
        for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE):
            embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch)
            embedding_response = client.batch_embed_contents(
                embedding_request,
                **request_options,
            )
            embedding_dict = type(embedding_response).to_dict(embedding_response)
            result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"])
        return result
    else:
        # Single-content case: one request; flatten the nested "values" list so
        # callers get `{"embedding": [floats]}`.
        embedding_request = protos.EmbedContentRequest(
            model=model,
            content=content_types.to_content(content),
            task_type=task_type,
            title=title,
            output_dimensionality=output_dimensionality,
        )
        embedding_response = client.embed_content(
            embedding_request,
            **request_options,
        )
        embedding_dict = type(embedding_response).to_dict(embedding_response)
        embedding_dict["embedding"] = embedding_dict["embedding"]["values"]
        return embedding_dict
220
+
221
+
222
@overload
async def embed_content_async(
    model: model_types.BaseModelNameOptions,
    content: content_types.ContentType,
    task_type: EmbeddingTaskTypeOptions | None = None,
    title: str | None = None,
    output_dimensionality: int | None = None,
    client: glm.GenerativeServiceAsyncClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.EmbeddingDict: ...


@overload
async def embed_content_async(
    model: model_types.BaseModelNameOptions,
    content: Iterable[content_types.ContentType],
    task_type: EmbeddingTaskTypeOptions | None = None,
    title: str | None = None,
    output_dimensionality: int | None = None,
    client: glm.GenerativeServiceAsyncClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.BatchEmbeddingDict: ...


async def embed_content_async(
    model: model_types.BaseModelNameOptions,
    content: content_types.ContentType | Iterable[content_types.ContentType],
    task_type: EmbeddingTaskTypeOptions | None = None,
    title: str | None = None,
    output_dimensionality: int | None = None,
    client: glm.GenerativeServiceAsyncClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
    """Calls the API to create async embeddings for content passed in.

    See `embed_content` for the argument descriptions; this is the async
    counterpart using the async generative-service client.

    Raises:
        ValueError: If a title is given without a `RETRIEVAL_DOCUMENT` task
            type, or if `output_dimensionality` is negative.
    """

    model = model_types.make_model_name(model)

    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_generative_async_client()

    # A title only makes sense for retrieval-document embeddings. Checking
    # `task_type is None` first gives a clear ValueError instead of the opaque
    # KeyError the task-type lookup would raise on None.
    if title and (
        task_type is None or to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT
    ):
        raise ValueError(
            f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}."
        )
    if output_dimensionality and output_dimensionality < 0:
        raise ValueError(
            f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}."
        )

    if task_type:
        task_type = to_task_type(task_type)

    if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)):
        # Batch case: one request per item, sent in API-sized chunks; vectors
        # from all chunks are flattened into a single list.
        result = {"embedding": []}
        requests = (
            protos.EmbedContentRequest(
                model=model,
                content=content_types.to_content(c),
                task_type=task_type,
                title=title,
                output_dimensionality=output_dimensionality,
            )
            for c in content
        )
        for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE):
            embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch)
            embedding_response = await client.batch_embed_contents(
                embedding_request,
                **request_options,
            )
            embedding_dict = type(embedding_response).to_dict(embedding_response)
            result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"])
        return result
    else:
        # Single-content case: one request; flatten the nested "values" list.
        embedding_request = protos.EmbedContentRequest(
            model=model,
            content=content_types.to_content(content),
            task_type=task_type,
            title=title,
            output_dimensionality=output_dimensionality,
        )
        embedding_response = await client.embed_content(
            embedding_request,
            **request_options,
        )
        embedding_dict = type(embedding_response).to_dict(embedding_response)
        embedding_dict["embedding"] = embedding_dict["embedding"]["values"]
        return embedding_dict
.venv/lib/python3.11/site-packages/google/generativeai/files.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ import pathlib
19
+ import mimetypes
20
+ from typing import Iterable
21
+ import logging
22
+ from google.generativeai import protos
23
+ from itertools import islice
24
+ from io import IOBase
25
+
26
+ from google.generativeai.types import file_types
27
+
28
+ from google.generativeai.client import get_default_file_client
29
+
30
+ __all__ = ["upload_file", "get_file", "list_files", "delete_file"]
31
+
32
+ mimetypes.add_type("image/webp", ".webp")
33
+
34
+
35
def upload_file(
    path: str | pathlib.Path | os.PathLike | IOBase,
    *,
    mime_type: str | None = None,
    name: str | None = None,
    display_name: str | None = None,
    resumable: bool = True,
) -> file_types.File:
    """Calls the API to upload a file using a supported file service.

    Args:
        path: The path to the file or a file-like object (e.g., BytesIO) to be uploaded.
        mime_type: The MIME type of the file. If not provided, it will be
            inferred from the file extension.
        name: The name of the file in the destination (e.g., 'files/sample-image').
            If not provided, a system generated ID will be created.
        display_name: Optional display name of the file.
        resumable: Whether to use the resumable upload protocol. By default, this is enabled.
            See details at
            https://googleapis.github.io/google-api-python-client/docs/epy/googleapiclient.http.MediaFileUpload-class.html#resumable

    Returns:
        file_types.File: The response of the uploaded file.
    """
    client = get_default_file_client()

    if isinstance(path, IOBase):
        # File-like objects carry no filename/extension, so the mime type
        # cannot be guessed and must be given explicitly.
        if mime_type is None:
            raise ValueError(
                "Unknown mime type: When passing a file like object to `path` (instead of a\n"
                " path-like object) you must set the `mime_type` argument"
            )
    else:
        path = pathlib.Path(os.fspath(path))

        if display_name is None:
            display_name = path.name

        if mime_type is None:
            mime_type, _ = mimetypes.guess_type(path)
        if mime_type is None:
            raise ValueError(
                "Unknown mime type: Could not determine the mimetype for your file\n"
                " please set the `mime_type` argument"
            )

    # Qualify bare names with the "files/" resource prefix.
    if name is not None and "/" not in name:
        name = f"files/{name}"

    response = client.create_file(
        path=path, mime_type=mime_type, name=name, display_name=display_name, resumable=resumable
    )
    return file_types.File(response)
89
+
90
+
91
def list_files(page_size: int = 100) -> Iterable[file_types.File]:
    """Calls the API to list files using a supported file service.

    Args:
        page_size: Maximum number of files fetched per underlying list call
            while iterating (the iterator itself yields all files).

    Yields:
        file_types.File: One wrapper object per file reported by the service.
    """
    client = get_default_file_client()

    response = client.list_files(protos.ListFilesRequest(page_size=page_size))
    for proto in response:
        yield file_types.File(proto)
98
+
99
+
100
def get_file(name: str) -> file_types.File:
    """Calls the API to retrieve a specified file using a supported file service."""
    # Accept bare ids by qualifying them with the "files/" resource prefix.
    if "/" not in name:
        name = f"files/{name}"
    client = get_default_file_client()
    proto = client.get_file(name=name)
    return file_types.File(proto)
106
+
107
+
108
def delete_file(name: str | file_types.File | protos.File):
    """Calls the API to permanently delete a specified file using a supported file service."""
    # Accept a File wrapper/proto (use its resource name) or a bare id string.
    if isinstance(name, (file_types.File, protos.File)):
        name = name.name
    elif "/" not in name:
        name = f"files/{name}"
    client = get_default_file_client()
    client.delete_file(request=protos.DeleteFileRequest(name=name))
.venv/lib/python3.11/site-packages/google/generativeai/generative_models.py ADDED
@@ -0,0 +1,875 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Classes for working with the Gemini models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Iterable
6
+ import textwrap
7
+ from typing import Any, Union, overload
8
+ import reprlib
9
+
10
+ # pylint: disable=bad-continuation, line-too-long
11
+
12
+
13
+ import google.api_core.exceptions
14
+ from google.generativeai import protos
15
+ from google.generativeai import client
16
+
17
+ from google.generativeai import caching
18
+ from google.generativeai.types import content_types
19
+ from google.generativeai.types import generation_types
20
+ from google.generativeai.types import helper_types
21
+ from google.generativeai.types import safety_types
22
+
23
+ _USER_ROLE = "user"
24
+ _MODEL_ROLE = "model"
25
+
26
+
27
+ class GenerativeModel:
28
+ """
29
+ The `genai.GenerativeModel` class wraps default parameters for calls to
30
+ `GenerativeModel.generate_content`, `GenerativeModel.count_tokens`, and
31
+ `GenerativeModel.start_chat`.
32
+
33
+ This family of functionality is designed to support multi-turn conversations, and multimodal
34
+ requests. What media-types are supported for input and output is model-dependant.
35
+
36
+ >>> import google.generativeai as genai
37
+ >>> import PIL.Image
38
+ >>> genai.configure(api_key='YOUR_API_KEY')
39
+ >>> model = genai.GenerativeModel('models/gemini-1.5-flash')
40
+ >>> result = model.generate_content('Tell me a story about a magic backpack')
41
+ >>> result.text
42
+ "In the quaint little town of Lakeside, there lived a young girl named Lily..."
43
+
44
+ Multimodal input:
45
+
46
+ >>> model = genai.GenerativeModel('models/gemini-1.5-flash')
47
+ >>> result = model.generate_content([
48
+ ... "Give me a recipe for these:", PIL.Image.open('scones.jpeg')])
49
+ >>> result.text
50
+ "**Blueberry Scones** ..."
51
+
52
+ Multi-turn conversation:
53
+
54
+ >>> chat = model.start_chat()
55
+ >>> response = chat.send_message("Hi, I have some questions for you.")
56
+ >>> response.text
57
+ "Sure, I'll do my best to answer your questions..."
58
+
59
+ To list the compatible model names use:
60
+
61
+ >>> for m in genai.list_models():
62
+ ... if 'generateContent' in m.supported_generation_methods:
63
+ ... print(m.name)
64
+
65
+ Arguments:
66
+ model_name: The name of the model to query. To list compatible models use
67
+ safety_settings: Sets the default safety filters. This controls which content is blocked
68
+ by the api before being returned.
69
+ generation_config: A `genai.GenerationConfig` setting the default generation parameters to
70
+ use.
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ model_name: str = "gemini-1.5-flash-002",
76
+ safety_settings: safety_types.SafetySettingOptions | None = None,
77
+ generation_config: generation_types.GenerationConfigType | None = None,
78
+ tools: content_types.FunctionLibraryType | None = None,
79
+ tool_config: content_types.ToolConfigType | None = None,
80
+ system_instruction: content_types.ContentType | None = None,
81
+ ):
82
+ if "/" not in model_name:
83
+ model_name = "models/" + model_name
84
+ self._model_name = model_name
85
+ self._safety_settings = safety_types.to_easy_safety_dict(safety_settings)
86
+ self._generation_config = generation_types.to_generation_config_dict(generation_config)
87
+ self._tools = content_types.to_function_library(tools)
88
+
89
+ if tool_config is None:
90
+ self._tool_config = None
91
+ else:
92
+ self._tool_config = content_types.to_tool_config(tool_config)
93
+
94
+ if system_instruction is None:
95
+ self._system_instruction = None
96
+ else:
97
+ self._system_instruction = content_types.to_content(system_instruction)
98
+
99
+ self._client = None
100
+ self._async_client = None
101
+
102
+ @property
103
+ def cached_content(self) -> str:
104
+ return getattr(self, "_cached_content", None)
105
+
106
+ @property
107
+ def model_name(self):
108
+ return self._model_name
109
+
110
    def __str__(self):
        """Returns a readable multi-line summary of the model's configuration."""

        def maybe_text(content):
            # Show just the first part's text (repr'd) when present, to keep
            # the summary short; otherwise fall back to the raw content object.
            if content and len(content.parts) and (t := content.parts[0].text):
                return repr(t)
            return content

        return textwrap.dedent(
            f"""\
            genai.GenerativeModel(
                model_name='{self.model_name}',
                generation_config={self._generation_config},
                safety_settings={self._safety_settings},
                tools={self._tools},
                system_instruction={maybe_text(self._system_instruction)},
                cached_content={self.cached_content}
            )"""
        )

    # repr mirrors str for convenient debugging output.
    __repr__ = __str__
129
+
130
+ def _prepare_request(
131
+ self,
132
+ *,
133
+ contents: content_types.ContentsType,
134
+ generation_config: generation_types.GenerationConfigType | None = None,
135
+ safety_settings: safety_types.SafetySettingOptions | None = None,
136
+ tools: content_types.FunctionLibraryType | None,
137
+ tool_config: content_types.ToolConfigType | None,
138
+ ) -> protos.GenerateContentRequest:
139
+ """Creates a `protos.GenerateContentRequest` from raw inputs."""
140
+ if hasattr(self, "_cached_content") and any([self._system_instruction, tools, tool_config]):
141
+ raise ValueError(
142
+ "`tools`, `tool_config`, `system_instruction` cannot be set on a model instantiated with `cached_content` as its context."
143
+ )
144
+
145
+ tools_lib = self._get_tools_lib(tools)
146
+ if tools_lib is not None:
147
+ tools_lib = tools_lib.to_proto()
148
+
149
+ if tool_config is None:
150
+ tool_config = self._tool_config
151
+ else:
152
+ tool_config = content_types.to_tool_config(tool_config)
153
+
154
+ contents = content_types.to_contents(contents)
155
+
156
+ generation_config = generation_types.to_generation_config_dict(generation_config)
157
+ merged_gc = self._generation_config.copy()
158
+ merged_gc.update(generation_config)
159
+
160
+ safety_settings = safety_types.to_easy_safety_dict(safety_settings)
161
+ merged_ss = self._safety_settings.copy()
162
+ merged_ss.update(safety_settings)
163
+ merged_ss = safety_types.normalize_safety_settings(merged_ss)
164
+
165
+ return protos.GenerateContentRequest(
166
+ model=self._model_name,
167
+ contents=contents,
168
+ generation_config=merged_gc,
169
+ safety_settings=merged_ss,
170
+ tools=tools_lib,
171
+ tool_config=tool_config,
172
+ system_instruction=self._system_instruction,
173
+ cached_content=self.cached_content,
174
+ )
175
+
176
+ def _get_tools_lib(
177
+ self, tools: content_types.FunctionLibraryType
178
+ ) -> content_types.FunctionLibrary | None:
179
+ if tools is None:
180
+ return self._tools
181
+ else:
182
+ return content_types.to_function_library(tools)
183
+
184
+ @overload
185
+ @classmethod
186
+ def from_cached_content(
187
+ cls,
188
+ cached_content: str,
189
+ *,
190
+ generation_config: generation_types.GenerationConfigType | None = None,
191
+ safety_settings: safety_types.SafetySettingOptions | None = None,
192
+ ) -> GenerativeModel: ...
193
+
194
+ @overload
195
+ @classmethod
196
+ def from_cached_content(
197
+ cls,
198
+ cached_content: caching.CachedContent,
199
+ *,
200
+ generation_config: generation_types.GenerationConfigType | None = None,
201
+ safety_settings: safety_types.SafetySettingOptions | None = None,
202
+ ) -> GenerativeModel: ...
203
+
204
+ @classmethod
205
+ def from_cached_content(
206
+ cls,
207
+ cached_content: str | caching.CachedContent,
208
+ *,
209
+ generation_config: generation_types.GenerationConfigType | None = None,
210
+ safety_settings: safety_types.SafetySettingOptions | None = None,
211
+ ) -> GenerativeModel:
212
+ """Creates a model with `cached_content` as model's context.
213
+
214
+ Args:
215
+ cached_content: context for the model.
216
+ generation_config: Overrides for the model's generation config.
217
+ safety_settings: Overrides for the model's safety settings.
218
+
219
+ Returns:
220
+ `GenerativeModel` object with `cached_content` as its context.
221
+ """
222
+ if isinstance(cached_content, str):
223
+ cached_content = caching.CachedContent.get(name=cached_content)
224
+
225
+ # call __init__ to set the model's `generation_config`, `safety_settings`.
226
+ # `model_name` will be the name of the model for which the `cached_content` was created.
227
+ self = cls(
228
+ model_name=cached_content.model,
229
+ generation_config=generation_config,
230
+ safety_settings=safety_settings,
231
+ )
232
+
233
+ # set the model's context.
234
+ setattr(self, "_cached_content", cached_content.name)
235
+ return self
236
+
237
    def generate_content(
        self,
        contents: content_types.ContentsType,
        *,
        generation_config: generation_types.GenerationConfigType | None = None,
        safety_settings: safety_types.SafetySettingOptions | None = None,
        stream: bool = False,
        tools: content_types.FunctionLibraryType | None = None,
        tool_config: content_types.ToolConfigType | None = None,
        request_options: helper_types.RequestOptionsType | None = None,
    ) -> generation_types.GenerateContentResponse:
        """A multipurpose function to generate responses from the model.

        This `GenerativeModel.generate_content` method can handle multimodal input, and multi-turn
        conversations.

        >>> model = genai.GenerativeModel('models/gemini-1.5-flash')
        >>> response = model.generate_content('Tell me a story about a magic backpack')
        >>> response.text

        ### Streaming

        This method supports streaming with the `stream=True`. The result has the same type as the non streaming case,
        but you can iterate over the response chunks as they become available:

        >>> response = model.generate_content('Tell me a story about a magic backpack', stream=True)
        >>> for chunk in response:
        ...     print(chunk.text)

        ### Multi-turn

        This method supports multi-turn chats but is **stateless**: the entire conversation history needs to be sent with each
        request. This takes some manual management but gives you complete control:

        >>> messages = [{'role':'user', 'parts': ['hello']}]
        >>> response = model.generate_content(messages) # "Hello, how can I help"
        >>> messages.append(response.candidates[0].content)
        >>> messages.append({'role':'user', 'parts': ['How does quantum physics work?']})
        >>> response = model.generate_content(messages)

        For a simpler multi-turn interface see `GenerativeModel.start_chat`.

        ### Input type flexibility

        While the underlying API strictly expects a `list[protos.Content]` objects, this method
        will convert the user input into the correct type. The hierarchy of types that can be
        converted is below. Any of these objects can be passed as an equivalent `dict`.

        * `Iterable[protos.Content]`
        * `protos.Content`
        * `Iterable[protos.Part]`
        * `protos.Part`
        * `str`, `Image`, or `protos.Blob`

        In an `Iterable[protos.Content]` each `content` is a separate message.
        But note that an `Iterable[protos.Part]` is taken as the parts of a single message.

        Arguments:
            contents: The contents serving as the model's prompt.
            generation_config: Overrides for the model's generation config.
            safety_settings: Overrides for the model's safety settings.
            stream: If True, yield response chunks as they are generated.
            tools: `protos.Tools` more info coming soon.
            request_options: Options for the request.
        """
        if not contents:
            raise TypeError("contents must not be empty")

        request = self._prepare_request(
            contents=contents,
            generation_config=generation_config,
            safety_settings=safety_settings,
            tools=tools,
            tool_config=tool_config,
        )

        # The API expects a role on the final message; default it to "user".
        if request.contents and not request.contents[-1].role:
            request.contents[-1].role = _USER_ROLE

        # Lazily create and cache the sync client on first use.
        if self._client is None:
            self._client = client.get_default_generative_client()

        if request_options is None:
            request_options = {}

        try:
            if stream:
                # Streaming: wrap the chunk iterator; rewrite_stream_error maps
                # low-level stream failures to friendlier exceptions.
                with generation_types.rewrite_stream_error():
                    iterator = self._client.stream_generate_content(
                        request,
                        **request_options,
                    )
                return generation_types.GenerateContentResponse.from_iterator(iterator)
            else:
                response = self._client.generate_content(
                    request,
                    **request_options,
                )
                return generation_types.GenerateContentResponse.from_response(response)
        except google.api_core.exceptions.InvalidArgument as e:
            # Augment payload-too-large errors with a pointer to the File API,
            # then re-raise the (mutated) original exception.
            if e.message.startswith("Request payload size exceeds the limit:"):
                e.message += (
                    " The file size is too large. Please use the File API to upload your files instead. "
                    "Example: `f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`"
                )
            raise
343
+
344
    async def generate_content_async(
        self,
        contents: content_types.ContentsType,
        *,
        generation_config: generation_types.GenerationConfigType | None = None,
        safety_settings: safety_types.SafetySettingOptions | None = None,
        stream: bool = False,
        tools: content_types.FunctionLibraryType | None = None,
        tool_config: content_types.ToolConfigType | None = None,
        request_options: helper_types.RequestOptionsType | None = None,
    ) -> generation_types.AsyncGenerateContentResponse:
        """The async version of `GenerativeModel.generate_content`."""
        if not contents:
            raise TypeError("contents must not be empty")

        request = self._prepare_request(
            contents=contents,
            generation_config=generation_config,
            safety_settings=safety_settings,
            tools=tools,
            tool_config=tool_config,
        )

        # The API expects a role on the final message; default it to "user".
        if request.contents and not request.contents[-1].role:
            request.contents[-1].role = _USER_ROLE

        # Lazily create and cache the async client on first use.
        if self._async_client is None:
            self._async_client = client.get_default_generative_async_client()

        if request_options is None:
            request_options = {}

        try:
            if stream:
                # Streaming: wrap the async chunk iterator; rewrite_stream_error
                # maps low-level stream failures to friendlier exceptions.
                with generation_types.rewrite_stream_error():
                    iterator = await self._async_client.stream_generate_content(
                        request,
                        **request_options,
                    )
                return await generation_types.AsyncGenerateContentResponse.from_aiterator(iterator)
            else:
                response = await self._async_client.generate_content(
                    request,
                    **request_options,
                )
                return generation_types.AsyncGenerateContentResponse.from_response(response)
        except google.api_core.exceptions.InvalidArgument as e:
            # Augment payload-too-large errors with a pointer to the File API,
            # then re-raise the (mutated) original exception.
            if e.message.startswith("Request payload size exceeds the limit:"):
                e.message += (
                    " The file size is too large. Please use the File API to upload your files instead. "
                    "Example: `f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`"
                )
            raise
397
+
398
+ # fmt: off
399
+ def count_tokens(
400
+ self,
401
+ contents: content_types.ContentsType = None,
402
+ *,
403
+ generation_config: generation_types.GenerationConfigType | None = None,
404
+ safety_settings: safety_types.SafetySettingOptions | None = None,
405
+ tools: content_types.FunctionLibraryType | None = None,
406
+ tool_config: content_types.ToolConfigType | None = None,
407
+ request_options: helper_types.RequestOptionsType | None = None,
408
+ ) -> protos.CountTokensResponse:
409
+ if request_options is None:
410
+ request_options = {}
411
+
412
+ if self._client is None:
413
+ self._client = client.get_default_generative_client()
414
+
415
+ request = protos.CountTokensRequest(
416
+ model=self.model_name,
417
+ generate_content_request=self._prepare_request(
418
+ contents=contents,
419
+ generation_config=generation_config,
420
+ safety_settings=safety_settings,
421
+ tools=tools,
422
+ tool_config=tool_config,
423
+ ))
424
+ return self._client.count_tokens(request, **request_options)
425
+
426
+ async def count_tokens_async(
427
+ self,
428
+ contents: content_types.ContentsType = None,
429
+ *,
430
+ generation_config: generation_types.GenerationConfigType | None = None,
431
+ safety_settings: safety_types.SafetySettingOptions | None = None,
432
+ tools: content_types.FunctionLibraryType | None = None,
433
+ tool_config: content_types.ToolConfigType | None = None,
434
+ request_options: helper_types.RequestOptionsType | None = None,
435
+ ) -> protos.CountTokensResponse:
436
+ if request_options is None:
437
+ request_options = {}
438
+
439
+ if self._async_client is None:
440
+ self._async_client = client.get_default_generative_async_client()
441
+
442
+ request = protos.CountTokensRequest(
443
+ model=self.model_name,
444
+ generate_content_request=self._prepare_request(
445
+ contents=contents,
446
+ generation_config=generation_config,
447
+ safety_settings=safety_settings,
448
+ tools=tools,
449
+ tool_config=tool_config,
450
+ ))
451
+ return await self._async_client.count_tokens(request, **request_options)
452
+
453
+ # fmt: on
454
+
455
+ def start_chat(
456
+ self,
457
+ *,
458
+ history: Iterable[content_types.StrictContentType] | None = None,
459
+ enable_automatic_function_calling: bool = False,
460
+ ) -> ChatSession:
461
+ """Returns a `genai.ChatSession` attached to this model.
462
+
463
+ >>> model = genai.GenerativeModel()
464
+ >>> chat = model.start_chat(history=[...])
465
+ >>> response = chat.send_message("Hello?")
466
+
467
+ Arguments:
468
+ history: An iterable of `protos.Content` objects, or equivalents to initialize the session.
469
+ """
470
+ if self._generation_config.get("candidate_count", 1) > 1:
471
+ raise ValueError(
472
+ "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1."
473
+ )
474
+ return ChatSession(
475
+ model=self,
476
+ history=history,
477
+ enable_automatic_function_calling=enable_automatic_function_calling,
478
+ )
479
+
480
+
481
class ChatSession:
    """Contains an ongoing conversation with the model.

    >>> model = genai.GenerativeModel('models/gemini-1.5-flash')
    >>> chat = model.start_chat()
    >>> response = chat.send_message("Hello")
    >>> print(response.text)
    >>> response = chat.send_message("Hello again")
    >>> print(response.text)
    >>> response = chat.send_message(...

    This `ChatSession` object collects the messages sent and received, in its
    `ChatSession.history` attribute.

    Arguments:
        model: The model to use in the chat.
        history: A chat history to initialize the object with.
        enable_automatic_function_calling: If True, `send_message` resolves the
            model's function calls automatically using the tool library.
    """

    def __init__(
        self,
        model: GenerativeModel,
        history: Iterable[content_types.StrictContentType] | None = None,
        enable_automatic_function_calling: bool = False,
    ):
        # The model used for all turns of this conversation.
        self.model: GenerativeModel = model
        # Committed turns of the conversation. The latest (sent, received) pair
        # is held separately below and only merged in once it is known-good
        # (see the `history` property).
        self._history: list[protos.Content] = content_types.to_contents(history)
        self._last_sent: protos.Content | None = None
        self._last_received: generation_types.BaseGenerateContentResponse | None = None
        self.enable_automatic_function_calling = enable_automatic_function_calling

    def send_message(
        self,
        content: content_types.ContentType,
        *,
        generation_config: generation_types.GenerationConfigType = None,
        safety_settings: safety_types.SafetySettingOptions = None,
        stream: bool = False,
        tools: content_types.FunctionLibraryType | None = None,
        tool_config: content_types.ToolConfigType | None = None,
        request_options: helper_types.RequestOptionsType | None = None,
    ) -> generation_types.GenerateContentResponse:
        """Sends the conversation history with the added message and returns the model's response.

        Appends the request and response to the conversation history.

        >>> model = genai.GenerativeModel('models/gemini-1.5-flash')
        >>> chat = model.start_chat()
        >>> response = chat.send_message("Hello")
        >>> print(response.text)
        "Hello! How can I assist you today?"
        >>> len(chat.history)
        2

        Call it with `stream=True` to receive response chunks as they are generated:

        >>> chat = model.start_chat()
        >>> response = chat.send_message("Explain quantum physics", stream=True)
        >>> for chunk in response:
        ...     print(chunk.text, end='')

        Once iteration over chunks is complete, the `response` and `ChatSession` are in states identical to the
        `stream=False` case. Some properties are not available until iteration is complete.

        Like `GenerativeModel.generate_content` this method lets you override the model's `generation_config` and
        `safety_settings`.

        Arguments:
            content: The message contents.
            generation_config: Overrides for the model's generation config.
            safety_settings: Overrides for the model's safety settings.
            stream: If True, yield response chunks as they are generated.
            tools: Overrides for the model's tools.
            tool_config: Overrides for the model's tool config.
            request_options: Options for the request.

        Raises:
            NotImplementedError: If `stream=True` is combined with automatic
                function calling.
            ValueError: If `candidate_count` > 1 is requested.
        """
        if request_options is None:
            request_options = {}

        # Streaming + automatic function calling is not supported by this SDK.
        if self.enable_automatic_function_calling and stream:
            raise NotImplementedError(
                "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`."
            )

        tools_lib = self.model._get_tools_lib(tools)

        content = content_types.to_content(content)

        # A message without an explicit role is treated as coming from the user.
        if not content.role:
            content.role = _USER_ROLE

        # Work on a copy: the new turn is only committed to self._history once
        # the response is confirmed good (via the `history` property).
        history = self.history[:]
        history.append(content)

        generation_config = generation_types.to_generation_config_dict(generation_config)
        if generation_config.get("candidate_count", 1) > 1:
            raise ValueError(
                "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1."
            )

        response = self.model.generate_content(
            contents=history,
            generation_config=generation_config,
            safety_settings=safety_settings,
            stream=stream,
            tools=tools_lib,
            tool_config=tool_config,
            request_options=request_options,
        )

        self._check_response(response=response, stream=stream)

        # Resolve any function calls in a loop until the model returns text.
        if self.enable_automatic_function_calling and tools_lib is not None:
            self.history, content, response = self._handle_afc(
                response=response,
                history=history,
                generation_config=generation_config,
                safety_settings=safety_settings,
                stream=stream,
                tools_lib=tools_lib,
                request_options=request_options,
            )

        # Hold the latest pair; it is merged into _history lazily.
        self._last_sent = content
        self._last_received = response

        return response

    def _check_response(self, *, response, stream):
        """Raises if the prompt was blocked or (non-streaming) the candidate stopped abnormally."""
        if response.prompt_feedback.block_reason:
            raise generation_types.BlockedPromptException(response.prompt_feedback)

        # For streaming responses the finish_reason isn't known yet; the
        # `history` property performs this check after iteration completes.
        if not stream:
            if response.candidates[0].finish_reason not in (
                protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED,
                protos.Candidate.FinishReason.STOP,
                protos.Candidate.FinishReason.MAX_TOKENS,
            ):
                raise generation_types.StopCandidateException(response.candidates[0])

    def _get_function_calls(self, response) -> list[protos.FunctionCall]:
        """Extracts the function-call parts from the single candidate of `response`."""
        candidates = response.candidates
        if len(candidates) != 1:
            raise ValueError(
                f"Invalid number of candidates: Automatic function calling only works with 1 candidate, but {len(candidates)} were provided."
            )
        parts = candidates[0].content.parts
        function_calls = [part.function_call for part in parts if part and "function_call" in part]
        return function_calls

    def _handle_afc(
        self,
        *,
        response,
        history,
        generation_config,
        safety_settings,
        stream,
        tools_lib,
        request_options,
    ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]:
        """Automatic function calling: executes requested tools and re-queries the model.

        Loops while the response contains function calls that map to callable
        declarations in `tools_lib`, appending each call/response pair to the
        history. Returns the updated history (minus the last message), the last
        message, and the final response.
        """
        while function_calls := self._get_function_calls(response):
            # Stop if any requested function has no callable implementation;
            # the caller must then handle the function calls manually.
            if not all(callable(tools_lib[fc]) for fc in function_calls):
                break
            history.append(response.candidates[0].content)

            function_response_parts: list[protos.Part] = []
            for fc in function_calls:
                fr = tools_lib(fc)
                assert fr is not None, (
                    "Unexpected state: The function reference (fr) should never be None. It should only return None if the declaration "
                    "is not callable, which is checked earlier in the code."
                )
                function_response_parts.append(fr)

            # Send the tool results back to the model as a user turn.
            send = protos.Content(role=_USER_ROLE, parts=function_response_parts)
            history.append(send)

            response = self.model.generate_content(
                contents=history,
                generation_config=generation_config,
                safety_settings=safety_settings,
                stream=stream,
                tools=tools_lib,
                request_options=request_options,
            )

            self._check_response(response=response, stream=stream)

        # Split off the last message so the caller can store it as _last_sent.
        *history, content = history
        return history, content, response

    async def send_message_async(
        self,
        content: content_types.ContentType,
        *,
        generation_config: generation_types.GenerationConfigType = None,
        safety_settings: safety_types.SafetySettingOptions = None,
        stream: bool = False,
        tools: content_types.FunctionLibraryType | None = None,
        tool_config: content_types.ToolConfigType | None = None,
        request_options: helper_types.RequestOptionsType | None = None,
    ) -> generation_types.AsyncGenerateContentResponse:
        """The async version of `ChatSession.send_message`."""
        if request_options is None:
            request_options = {}

        # Streaming + automatic function calling is not supported by this SDK.
        if self.enable_automatic_function_calling and stream:
            raise NotImplementedError(
                "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`."
            )

        tools_lib = self.model._get_tools_lib(tools)

        content = content_types.to_content(content)

        # A message without an explicit role is treated as coming from the user.
        if not content.role:
            content.role = _USER_ROLE

        # Work on a copy; the new turn is committed lazily (see `history`).
        history = self.history[:]
        history.append(content)

        generation_config = generation_types.to_generation_config_dict(generation_config)
        if generation_config.get("candidate_count", 1) > 1:
            raise ValueError(
                "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1."
            )

        response = await self.model.generate_content_async(
            contents=history,
            generation_config=generation_config,
            safety_settings=safety_settings,
            stream=stream,
            tools=tools_lib,
            tool_config=tool_config,
            request_options=request_options,
        )

        self._check_response(response=response, stream=stream)

        # Resolve any function calls in a loop until the model returns text.
        if self.enable_automatic_function_calling and tools_lib is not None:
            self.history, content, response = await self._handle_afc_async(
                response=response,
                history=history,
                generation_config=generation_config,
                safety_settings=safety_settings,
                stream=stream,
                tools_lib=tools_lib,
                request_options=request_options,
            )

        self._last_sent = content
        self._last_received = response

        return response

    async def _handle_afc_async(
        self,
        *,
        response,
        history,
        generation_config,
        safety_settings,
        stream,
        tools_lib,
        request_options,
    ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]:
        """The async version of `ChatSession._handle_afc`."""
        while function_calls := self._get_function_calls(response):
            # Stop if any requested function has no callable implementation.
            if not all(callable(tools_lib[fc]) for fc in function_calls):
                break
            history.append(response.candidates[0].content)

            function_response_parts: list[protos.Part] = []
            for fc in function_calls:
                fr = tools_lib(fc)
                assert fr is not None, (
                    "Unexpected state: The function reference (fr) should never be None. It should only return None if the declaration "
                    "is not callable, which is checked earlier in the code."
                )
                function_response_parts.append(fr)

            # Send the tool results back to the model as a user turn.
            send = protos.Content(role=_USER_ROLE, parts=function_response_parts)
            history.append(send)

            response = await self.model.generate_content_async(
                contents=history,
                generation_config=generation_config,
                safety_settings=safety_settings,
                stream=stream,
                tools=tools_lib,
                request_options=request_options,
            )

            self._check_response(response=response, stream=stream)

        # Split off the last message so the caller can store it as _last_sent.
        *history, content = history
        return history, content, response

    def __copy__(self):
        # NOTE(review): `enable_automatic_function_calling` is not carried over
        # to the copy (it falls back to the default False) — confirm intended.
        return ChatSession(
            model=self.model,
            # Be sure the copy doesn't share the history.
            history=list(self.history),
        )

    def rewind(self) -> tuple[protos.Content, protos.Content]:
        """Removes the last request/response pair from the chat history."""
        if self._last_received is None:
            # The last pair was already merged into _history; pop both entries.
            result = self._history.pop(-2), self._history.pop()
            return result
        else:
            # The last pair was still pending; just drop it.
            result = self._last_sent, self._last_received.candidates[0].content
            self._last_sent = None
            self._last_received = None
            return result

    @property
    def last(self) -> generation_types.BaseGenerateContentResponse | None:
        """returns the last received `genai.GenerateContentResponse`"""
        return self._last_received

    @property
    def history(self) -> list[protos.Content]:
        """The chat history.

        Accessing this property also commits the pending (sent, received) pair
        into the stored history, raising `BrokenResponseError` if the last
        response ended abnormally or errored mid-stream.
        """
        last = self._last_received
        if last is None:
            return self._history

        # Flag abnormal finish reasons (e.g. safety stops) as errors.
        if last.candidates[0].finish_reason not in (
            protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED,
            protos.Candidate.FinishReason.STOP,
            protos.Candidate.FinishReason.MAX_TOKENS,
        ):
            error = generation_types.StopCandidateException(last.candidates[0])
            last._error = error

        if last._error is not None:
            raise generation_types.BrokenResponseError(
                "Unable to build a coherent chat history due to a broken streaming response. "
                "Refer to the previous exception for details. "
                "To inspect the last response object, use `chat.last`. "
                "To remove the last request/response `Content` objects from the chat, "
                "call `last_send, last_received = chat.rewind()` and continue without it."
            ) from last._error

        # Commit the pending pair into the stored history.
        sent = self._last_sent
        received = last.candidates[0].content
        if not received.role:
            received.role = _MODEL_ROLE
        self._history.extend([sent, received])

        self._last_sent = None
        self._last_received = None

        return self._history

    @history.setter
    def history(self, history):
        # Replacing the history discards any pending (sent, received) pair.
        self._history = content_types.to_contents(history)
        self._last_sent = None
        self._last_received = None

    def __repr__(self) -> str:
        # Use reprlib to keep long content objects abbreviated.
        _dict_repr = reprlib.Repr()
        _model = str(self.model).replace("\n", "\n" + " " * 4)

        def content_repr(x):
            return f"protos.Content({_dict_repr.repr(type(x).to_dict(x))})"

        # Reading `self.history` may raise if the last response was broken;
        # fall back to the raw stored history in that case.
        try:
            history = list(self.history)
        except (generation_types.BrokenResponseError, generation_types.IncompleteIterationError):
            history = list(self._history)

        if self._last_sent is not None:
            history.append(self._last_sent)
        history = [content_repr(x) for x in history]

        last_received = self._last_received
        if last_received is not None:
            if last_received._error is not None:
                history.append("<STREAMING ERROR>")
            else:
                history.append("<STREAMING IN PROGRESS>")

        _history = ",\n    " + f"history=[{', '.join(history)}]\n)"

        return (
            textwrap.dedent(
                f"""\
                ChatSession(
                    model="""
            )
            + _model
            + _history
        )
.venv/lib/python3.11/site-packages/google/generativeai/models.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ import typing
18
+ from typing import Any, Literal
19
+
20
+ import google.ai.generativelanguage as glm
21
+
22
+ from google.generativeai import protos
23
+ from google.generativeai import operations
24
+ from google.generativeai.client import get_default_model_client
25
+ from google.generativeai.types import model_types
26
+ from google.generativeai.types import helper_types
27
+ from google.api_core import operation
28
+ from google.api_core import protobuf_helpers
29
+ from google.protobuf import field_mask_pb2
30
+ from google.generativeai.utils import flatten_update_paths
31
+
32
+
33
def get_model(
    name: model_types.AnyModelNameOptions,
    *,
    client=None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.Model | model_types.TunedModel:
    """Calls the API to fetch a model by name.

    ```
    import pprint
    model = genai.get_model('models/gemini-1.5-flash')
    pprint.pprint(model)
    ```

    Args:
        name: The name of the model to fetch. Should start with `models/`
        client: The client to use.
        request_options: Options for the request.

    Returns:
        A `types.Model` or `types.TunedModel`, depending on the name prefix.
    """
    name = model_types.make_model_name(name)
    # Dispatch on the resource prefix: base models vs. tuned models use
    # different API calls and return different types.
    if name.startswith("models/"):
        return get_base_model(name, client=client, request_options=request_options)
    if name.startswith("tunedModels/"):
        return get_tuned_model(name, client=client, request_options=request_options)
    raise ValueError(
        f"Invalid model name: Model names must start with `models/` or `tunedModels/`. Received: {name}"
    )
64
+
65
+
66
def get_base_model(
    name: model_types.BaseModelNameOptions,
    *,
    client=None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.Model:
    """Calls the API to fetch a base model by name.

    ```
    import pprint
    model = genai.get_base_model('models/chat-bison-001')
    pprint.pprint(model)
    ```

    Args:
        name: The name of the model to fetch. Should start with `models/`
        client: The client to use.
        request_options: Options for the request.

    Returns:
        A `types.Model`.

    Raises:
        ValueError: If `name` does not start with `models/`.
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_model_client()

    name = model_types.make_model_name(name)
    if not name.startswith("models/"):
        raise ValueError(
            f"Invalid model name: Base model names must start with `models/`. Received: {name}"
        )

    # Convert the proto response to a plain dict, then to the dataclass type.
    raw = client.get_model(name=name, **request_options)
    return model_types.Model(**type(raw).to_dict(raw))
103
+
104
+
105
def get_tuned_model(
    name: model_types.TunedModelNameOptions,
    *,
    client=None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.TunedModel:
    """Calls the API to fetch a tuned model by name.

    ```
    import pprint
    model = genai.get_tuned_model('tunedModels/gemini-1.5-flash')
    pprint.pprint(model)
    ```

    Args:
        name: The name of the model to fetch. Should start with `tunedModels/`
        client: The client to use.
        request_options: Options for the request.

    Returns:
        A `types.TunedModel`.

    Raises:
        ValueError: If `name` does not start with `tunedModels/`.
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_model_client()

    name = model_types.make_model_name(name)
    if not name.startswith("tunedModels/"):
        raise ValueError(
            f"Invalid model name: Tuned model names must start with `tunedModels/`. Received: {name}"
        )

    raw = client.get_tuned_model(name=name, **request_options)
    return model_types.decode_tuned_model(raw)
143
+
144
+
145
def get_base_model_name(
    model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None
):
    """Calls the API to fetch the base model name of a model.

    Accepts a model name string, a `model_types.Model`/`TunedModel`, or a
    `protos.Model`/`protos.TunedModel`. Only a `tunedModels/...` string
    triggers an API call; every other input is resolved locally.
    """
    if isinstance(model, str):
        if not model.startswith("tunedModels/"):
            return model
        # Only a tuned-model name requires an API round trip to resolve.
        fetched = get_model(model, client=client)
        return fetched.base_model
    if isinstance(model, model_types.TunedModel):
        return model.base_model
    if isinstance(model, model_types.Model):
        return model.name
    if isinstance(model, protos.Model):
        return model.name
    if isinstance(model, protos.TunedModel):
        # A tuned proto either names its base model directly, or points at it
        # through its tuned-model source.
        direct = getattr(model, "base_model", None)
        if direct:
            return direct
        return model.tuned_model_source.base_model
    raise TypeError(
        f"Invalid model: The provided model '{model}' is not recognized or supported. "
        "Supported types are: str, model_types.TunedModel, model_types.Model, protos.Model, and protos.TunedModel."
    )
173
+
174
+
175
def list_models(
    *,
    page_size: int | None = 50,
    client: glm.ModelServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.ModelsIterable:
    """Calls the API to list all available models.

    ```
    import pprint
    for model in genai.list_models():
        pprint.pprint(model)
    ```

    Args:
        page_size: How many `types.Models` to fetch per page (api call).
        client: You may pass a `glm.ModelServiceClient` instead of using the default client.
        request_options: Options for the request.

    Yields:
        `types.Model` objects.

    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_model_client()

    # The client pager handles pagination; convert each proto to a dataclass.
    for raw in client.list_models(page_size=page_size, **request_options):
        yield model_types.Model(**type(raw).to_dict(raw))
207
+
208
+
209
def list_tuned_models(
    *,
    page_size: int | None = 50,
    client: glm.ModelServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.TunedModelsIterable:
    """Calls the API to list all tuned models.

    ```
    import pprint
    for model in genai.list_tuned_models():
        pprint.pprint(model)
    ```

    Args:
        page_size: How many `types.Models` to fetch per page (api call).
        client: You may pass a `glm.ModelServiceClient` instead of using the default client.
        request_options: Options for the request.

    Yields:
        `types.TunedModel` objects.
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_model_client()

    # The client pager handles pagination; decode each proto into the
    # dataclass representation.
    for raw in client.list_tuned_models(page_size=page_size, **request_options):
        yield model_types.decode_tuned_model(type(raw).to_dict(raw))
243
+
244
+
245
def create_tuned_model(
    source_model: model_types.AnyModelNameOptions,
    training_data: model_types.TuningDataOptions,
    *,
    id: str | None = None,
    display_name: str | None = None,
    description: str | None = None,
    temperature: float | None = None,
    top_p: float | None = None,
    top_k: int | None = None,
    epoch_count: int | None = None,
    batch_size: int | None = None,
    learning_rate: float | None = None,
    input_key: str = "text_input",
    output_key: str = "output",
    client: glm.ModelServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> operations.CreateTunedModelOperation:
    """Calls the API to start a tuning job, returning an operation that tracks it.

    Tuning can take significant time, so this call does not wait for it to
    complete. It returns a `google.api_core.operation.Operation`-style object
    that can report the job's status, wait for completion, and expose the
    resulting `TunedModel` via `Operation.result()`. The finished model is also
    visible through `list_tuned_models` / `get_tuned_model(model_id)`.

    ```
    my_id = "my-tuned-model-id"
    operation = palm.create_tuned_model(
        id = my_id,
        source_model="models/text-bison-001",
        training_data=[{'text_input': 'example input', 'output': 'example output'},...]
    )
    tuned_model=operation.result() # Wait for tuning to finish

    palm.generate_text(f"tunedModels/{my_id}", prompt="...")
    ```

    Args:
        source_model: The name of the model to tune.
        training_data: The dataset to tune the model on. This must be either:
            * A `protos.Dataset`, or
            * An `Iterable` of `protos.TuningExample`,
              `{'text_input': ..., 'output': ...}` dicts, or
              `(text_input, output)` tuples.
            * A `Mapping` of `Iterable[str]` - use `input_key` and `output_key`
              to choose which columns to use as the input/output.
            * A csv file (read with `pd.read_csv` and handled as a `Mapping`):
              a local path, a url, or the url of a Google Sheets file.
            * A JSON file (handled as an `Iterable` or `Mapping` above):
              a local path as a `str` or `pathlib.Path`.
        id: The model identifier, used to refer to the model in the API
            `tunedModels/{id}`. Must be unique.
        display_name: A human-readable name for display.
        description: A description of the tuned model.
        temperature: The default temperature for the tuned model, see `types.Model` for details.
        top_p: The default `top_p` for the model, see `types.Model` for details.
        top_k: The default `top_k` for the model, see `types.Model` for details.
        epoch_count: The number of tuning epochs to run. An epoch is a pass over the whole dataset.
        batch_size: The number of examples to use in each training batch.
        learning_rate: The step size multiplier for the gradient updates.
        input_key: Column to use as input when `training_data` is a mapping/csv.
        output_key: Column to use as output when `training_data` is a mapping/csv.
        client: Which client to use.
        request_options: Options for the request.

    Returns:
        A [`google.api_core.operation.Operation`](https://googleapis.dev/python/google-api-core/latest/operation.html)
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_model_client()

    source_model_name = model_types.make_model_name(source_model)
    base_model_name = get_base_model_name(source_model)

    # Describe the tuning source: either a base model directly, or an existing
    # tuned model together with its base.
    if source_model_name.startswith("models/"):
        model_source = {"base_model": source_model_name}
    elif source_model_name.startswith("tunedModels/"):
        model_source = {
            "tuned_model_source": {
                "tuned_model": source_model_name,
                "base_model": base_model_name,
            }
        }
    else:
        raise ValueError(
            f"Invalid model name: The provided model '{source_model}' does not match any known model patterns such as 'models/' or 'tunedModels/'"
        )

    encoded_data = model_types.encode_tuning_data(
        training_data, input_key=input_key, output_key=output_key
    )

    tuning_task = protos.TuningTask(
        training_data=encoded_data,
        hyperparameters=protos.Hyperparameters(
            epoch_count=epoch_count,
            batch_size=batch_size,
            learning_rate=learning_rate,
        ),
    )

    tuned_model = protos.TunedModel(
        **model_source,
        display_name=display_name,
        description=description,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        tuning_task=tuning_task,
    )

    core_operation = client.create_tuned_model(
        dict(tuned_model_id=id, tuned_model=tuned_model), **request_options
    )

    # Wrap the raw long-running operation in the SDK's typed helper.
    return operations.CreateTunedModelOperation.from_core_operation(core_operation)
369
+
370
+
371
@typing.overload
def update_tuned_model(
    tuned_model: protos.TunedModel,
    updates: None = None,
    *,
    client: glm.ModelServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.TunedModel:
    pass


@typing.overload
def update_tuned_model(
    tuned_model: str,
    updates: dict[str, Any],
    *,
    client: glm.ModelServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.TunedModel:
    pass


def update_tuned_model(
    tuned_model: str | protos.TunedModel,
    updates: dict[str, Any] | None = None,
    *,
    client: glm.ModelServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.TunedModel:
    """Calls the API to push updates to a specified tuned model where only certain attributes are updatable.

    Args:
        tuned_model: Either the name (`str`) of the tuned model, or a
            `protos.TunedModel` object carrying the edited fields.
        updates: When `tuned_model` is a name, a dict mapping dotted field
            paths (possibly nested dicts, flattened here) to new values.
            Must be `None` when a `protos.TunedModel` is passed.
        client: Which client to use.
        request_options: Options for the request.

    Returns:
        The updated model, decoded as a `model_types.TunedModel`.

    Raises:
        TypeError: If `tuned_model` is a `str` but `updates` is not a `dict`,
            or if `tuned_model` is neither a `str` nor a `protos.TunedModel`.
        ValueError: If a `protos.TunedModel` is passed together with `updates`.
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_model_client()

    if isinstance(tuned_model, str):
        name = tuned_model
        if not isinstance(updates, dict):
            raise TypeError(
                f"Invalid argument type: In the function `update_tuned_model(name:str, updates: dict)`, the `updates` argument must be of type `dict`. Received type: {type(updates).__name__}."
            )

        tuned_model = client.get_tuned_model(name=name, **request_options)

        # Flatten nested dicts into dotted paths so the same keys can feed
        # both the FieldMask and the attribute traversal in `_apply_update`.
        updates = flatten_update_paths(updates)
        field_mask = field_mask_pb2.FieldMask()
        for path in updates.keys():
            field_mask.paths.append(path)
        for path, value in updates.items():
            _apply_update(tuned_model, path, value)
    elif isinstance(tuned_model, protos.TunedModel):
        if updates is not None:
            raise ValueError(
                "Invalid argument: When calling `update_tuned_model(tuned_model:protos.TunedModel, updates=None)`, "
                "the `updates` argument must not be set."
            )

        name = tuned_model.name
        # Diff against the server's current copy to compute the update mask.
        # Fix: pass request_options here too, consistent with the other RPC
        # calls in this function (it was previously dropped).
        was = client.get_tuned_model(name=name, **request_options)
        field_mask = protobuf_helpers.field_mask(was._pb, tuned_model._pb)
    else:
        raise TypeError(
            "Invalid argument type: In the function `update_tuned_model(tuned_model:dict|protos.TunedModel)`, the "
            f"`tuned_model` argument must be of type `dict` or `protos.TunedModel`. Received type: {type(tuned_model).__name__}."
        )

    result = client.update_tuned_model(
        protos.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask),
        **request_options,
    )
    return model_types.decode_tuned_model(result)
444
+
445
+
446
+ def _apply_update(thing, path, value):
447
+ parts = path.split(".")
448
+ for part in parts[:-1]:
449
+ thing = getattr(thing, part)
450
+ setattr(thing, parts[-1], value)
451
+
452
+
453
def delete_tuned_model(
    tuned_model: model_types.TunedModelNameOptions,
    client: glm.ModelServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
) -> None:
    """Calls the API to delete a specified tuned model.

    Args:
        tuned_model: The tuned model to delete (name or model object).
        client: Which client to use.
        request_options: Options for the request.
    """
    if client is None:
        client = get_default_model_client()
    if request_options is None:
        request_options = {}

    model_name = model_types.make_model_name(tuned_model)
    client.delete_tuned_model(name=model_name, **request_options)
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Notebook extensions for Generative AI."""
16
+
17
+
18
def load_ipython_extension(ipython):
    """Register the Colab Magic extension to support %load_ext."""
    # Imported lazily so the magics module only loads when the extension is
    # actually activated inside IPython.
    # pylint: disable-next=g-import-not-at-top
    from google.generativeai.notebook import magics

    ipython.register_magics(magics.Magics)

    # In an interactive environment, opt in to the prettier table rendering.
    # Outside Colab the import fails and we silently keep the defaults.
    try:
        # pylint: disable-next=g-import-not-at-top
        from google import colab  # type: ignore

        colab.data_table.enable_dataframe_formatter()
    except ImportError:
        pass
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/command.cpython-311.pyc ADDED
Binary file (1.77 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/compare_cmd.cpython-311.pyc ADDED
Binary file (2.84 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/eval_cmd.cpython-311.pyc ADDED
Binary file (3.01 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/flag_def.cpython-311.pyc ADDED
Binary file (21.9 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/input_utils.cpython-311.pyc ADDED
Binary file (3.62 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/magics_engine.cpython-311.pyc ADDED
Binary file (5.64 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/parsed_args_lib.cpython-311.pyc ADDED
Binary file (3.21 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/run_cmd.cpython-311.pyc ADDED
Binary file (3 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_id.cpython-311.pyc ADDED
Binary file (4.92 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/notebook/__pycache__/sheets_utils.cpython-311.pyc ADDED
Binary file (6.03 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/notebook/cmd_line_parser.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Parses an LLM command line."""
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import shlex
20
+ import sys
21
+ from typing import AbstractSet, Any, Callable, MutableMapping, Sequence
22
+
23
+ from google.generativeai.notebook import argument_parser
24
+ from google.generativeai.notebook import flag_def
25
+ from google.generativeai.notebook import input_utils
26
+ from google.generativeai.notebook import model_registry
27
+ from google.generativeai.notebook import output_utils
28
+ from google.generativeai.notebook import parsed_args_lib
29
+ from google.generativeai.notebook import post_process_utils
30
+ from google.generativeai.notebook import py_utils
31
+ from google.generativeai.notebook import sheets_utils
32
+ from google.generativeai.notebook.lib import llm_function
33
+ from google.generativeai.notebook.lib import llmfn_inputs_source
34
+ from google.generativeai.notebook.lib import llmfn_outputs
35
+ from google.generativeai.notebook.lib import model as model_lib
36
+
37
+
38
+ _MIN_CANDIDATE_COUNT = 1
39
+ _MAX_CANDIDATE_COUNT = 8
40
+
41
+
42
+ def _validate_input_source_against_placeholders(
43
+ source: llmfn_inputs_source.LLMFnInputsSource,
44
+ placeholders: AbstractSet[str],
45
+ ) -> None:
46
+ for inputs in source.to_normalized_inputs():
47
+ for keyword in placeholders:
48
+ if keyword not in inputs:
49
+ raise ValueError('Placeholder "{}" not found in input'.format(keyword))
50
+
51
+
52
def _get_resolve_input_from_py_var_fn(
    placeholders: AbstractSet[str] | None,
) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]:
    """Returns a resolver that loads inputs from a named Python variable.

    The returned callable also validates the loaded inputs against
    `placeholders` when any are provided.
    """

    def _resolve(var_name: str) -> llmfn_inputs_source.LLMFnInputsSource:
        inputs_source = input_utils.get_inputs_source_from_py_var(var_name)
        if placeholders:
            _validate_input_source_against_placeholders(inputs_source, placeholders)
        return inputs_source

    return _resolve
62
+
63
+
64
def _resolve_compare_fn_var(
    name: str,
) -> tuple[str, parsed_args_lib.TextResultCompareFn]:
    """Resolves a value passed into --compare_fn.

    Args:
        name: The name of a Python variable expected to hold a callable.

    Returns:
        A (name, callable) tuple.

    Raises:
        ValueError: If the variable does not hold a callable object.
    """
    fn = py_utils.get_py_var(name)
    # `callable()` is the idiomatic check; `isinstance(fn, typing.Callable)`
    # only inspects `__call__` anyway.
    if not callable(fn):
        raise ValueError('Variable "{}" does not contain a Callable object'.format(name))

    return name, fn
73
+
74
+
75
def _resolve_ground_truth_var(name: str) -> Sequence[str]:
    """Resolves a value passed into --ground_truth.

    The variable must hold a non-string Sequence whose elements are all
    strings (e.g. a list of strings).
    """
    value = py_utils.get_py_var(name)

    # "str" and "bytes" are themselves Sequences, so rule them out explicitly:
    # we want an actual container of strings, like a list.
    is_container = isinstance(value, Sequence) and not isinstance(value, (str, bytes))
    if not is_container or not all(isinstance(item, str) for item in value):
        raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name))
    return value
87
+
88
+
89
def _get_resolve_sheets_inputs_fn(
    placeholders: AbstractSet[str] | None,
) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]:
    """Returns a resolver that loads inputs from a Google Sheets reference."""

    def _resolve(value: str) -> llmfn_inputs_source.LLMFnInputsSource:
        inputs_source = sheets_utils.SheetsInputs(sheets_utils.get_sheets_id_from_str(value))
        if placeholders:
            _validate_input_source_against_placeholders(inputs_source, placeholders)
        return inputs_source

    return _resolve
100
+
101
+
102
def _resolve_sheets_outputs(value: str) -> llmfn_outputs.LLMFnOutputsSink:
    """Converts a Sheets reference string into a SheetsOutputs sink."""
    return sheets_utils.SheetsOutputs(sheets_utils.get_sheets_id_from_str(value))
105
+
106
+
107
def _add_model_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags that are related to model selection and config."""
    # --model_type/-mt: which registered model implementation to use.
    flag_def.EnumFlagDef(
        name="model_type",
        short_name="mt",
        enum_type=model_registry.ModelName,
        default_value=model_registry.ModelRegistry.DEFAULT_MODEL,
        help_msg="The type of model to use.",
    ).add_argument_to_parser(parser)

    def _check_is_greater_than_or_equal_to_zero(x: float) -> float:
        # Validator for --temperature: negative values are rejected.
        if x < 0:
            raise ValueError("Value should be greater than or equal to zero, got {}".format(x))
        return x

    flag_def.SingleValueFlagDef(
        name="temperature",
        short_name="t",
        parse_type=float,
        # Use None for default value to indicate that this will use the default
        # value in Text service.
        default_value=None,
        parse_to_dest_type_fn=_check_is_greater_than_or_equal_to_zero,
        help_msg=(
            "Controls the randomness of the output. Must be positive. Typical"
            " values are in the range: [0.0, 1.0]. Higher values produce a more"
            " random and varied response. A temperature of zero will be"
            " deterministic."
        ),
    ).add_argument_to_parser(parser)

    flag_def.SingleValueFlagDef(
        name="model",
        short_name="m",
        default_value=None,
        help_msg=(
            "The name of the model to use. If not provided, a default model will" " be used."
        ),
    ).add_argument_to_parser(parser)

    def _check_candidate_count_range(x: Any) -> int:
        # Validator for --candidate_count: must fall within
        # [_MIN_CANDIDATE_COUNT, _MAX_CANDIDATE_COUNT].
        if x < _MIN_CANDIDATE_COUNT or x > _MAX_CANDIDATE_COUNT:
            raise ValueError(
                "Value should be in the range [{}, {}], got {}".format(
                    _MIN_CANDIDATE_COUNT, _MAX_CANDIDATE_COUNT, x
                )
            )
        return int(x)

    flag_def.SingleValueFlagDef(
        name="candidate_count",
        short_name="cc",
        parse_type=int,
        # Use None for default value to indicate that this will use the default
        # value in Text service.
        default_value=None,
        parse_to_dest_type_fn=_check_candidate_count_range,
        help_msg="The number of candidates to produce.",
    ).add_argument_to_parser(parser)

    flag_def.BooleanFlagDef(
        name="unique",
        help_msg="Whether to dedupe candidates returned by the model.",
    ).add_argument_to_parser(parser)
172
+ ).add_argument_to_parser(parser)
173
+
174
+
175
def _add_input_flags(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags to read inputs from a Python variable or Sheets.

    Args:
        parser: The parser to which flags will be added.
        placeholders: Prompt placeholders; the flag resolvers validate loaded
            inputs against them when provided.
    """
    flag_def.MultiValuesFlagDef(
        name="inputs",
        short_name="i",
        dest_type=llmfn_inputs_source.LLMFnInputsSource,
        parse_to_dest_type_fn=_get_resolve_input_from_py_var_fn(placeholders),
        help_msg=(
            "Optional names of Python variables containing inputs to use to"
            " instantiate a prompt. The variable must be either: a dictionary"
            " {'key1': ['val1', 'val2'] ...}, or an instance of LLMFnInputsSource"
            # Fix: the class is named "SheetsInputs" (see the resolver above),
            # not "SheetsInput".
            " such as SheetsInputs."
        ),
    ).add_argument_to_parser(parser)

    flag_def.MultiValuesFlagDef(
        name="sheets_input_names",
        short_name="si",
        dest_type=llmfn_inputs_source.LLMFnInputsSource,
        parse_to_dest_type_fn=_get_resolve_sheets_inputs_fn(placeholders),
        help_msg=(
            "Optional names of Google Sheets to read inputs from. This is"
            " equivalent to using --inputs with the names of variables that are"
            " instances of SheetsInputs, just more convenient to use."
        ),
    ).add_argument_to_parser(parser)
204
+
205
+
206
def _add_output_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags to write outputs to a Python variable or Sheets.

    Args:
        parser: The parser to which flags will be added.
    """
    flag_def.MultiValuesFlagDef(
        name="outputs",
        short_name="o",
        dest_type=llmfn_outputs.LLMFnOutputsSink,
        parse_to_dest_type_fn=output_utils.get_outputs_sink_from_py_var,
        help_msg=(
            "Optional names of Python variables to output to. If the Python"
            " variable has not already been defined, it will be created. If the"
            " variable is defined and is an instance of LLMFnOutputsSink, the"
            " outputs will be written through the sink's write_outputs() method."
        ),
    ).add_argument_to_parser(parser)

    flag_def.MultiValuesFlagDef(
        name="sheets_output_names",
        short_name="so",
        dest_type=llmfn_outputs.LLMFnOutputsSink,
        parse_to_dest_type_fn=_resolve_sheets_outputs,
        help_msg=(
            # Fix: this flag writes *outputs*, not inputs (copy-paste typo).
            "Optional names of Google Sheets to write outputs to. This is"
            " equivalent to using --outputs with the names of variables that are"
            " instances of SheetsOutputs, just more convenient to use."
        ),
    ).add_argument_to_parser(parser)
234
+
235
+
236
def _add_compare_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds the --compare_fn flag (used by the compare and eval commands)."""
    flag_def.MultiValuesFlagDef(
        name="compare_fn",
        dest_type=tuple,
        # Resolves each value to a (name, callable) pair.
        parse_to_dest_type_fn=_resolve_compare_fn_var,
        help_msg=(
            "An optional function that takes two inputs: (lhs_result, rhs_result)"
            " which are the results of the left- and right-hand side functions. "
            "Multiple comparison functions can be provided."
        ),
    ).add_argument_to_parser(parser)
249
+
250
+
251
def _add_eval_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds the required --ground_truth flag for the eval command."""
    flag_def.SingleValueFlagDef(
        name="ground_truth",
        required=True,
        dest_type=Sequence,
        # Resolves the variable name to a Sequence of strings.
        parse_to_dest_type_fn=_resolve_ground_truth_var,
        help_msg=(
            "A variable containing a Sequence of strings representing the ground"
            " truth that the output of this cell will be compared against. It"
            " should have the same number of entries as inputs."
        ),
    ).add_argument_to_parser(parser)
265
+
266
+
267
def _create_run_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags for the `run` command.

    `run` sends one or more prompts to a model.

    Args:
        parser: The parser to which flags will be added.
        placeholders: Placeholders from prompts in the cell contents.
    """
    # Order matters for the generated --help output.
    _add_model_flags(parser)
    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)
282
+
283
+
284
def _create_compile_parser(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags for the compile command.

    `compile` "compiles" a prompt and model call into a callable function.

    Args:
        parser: The parser to which flags will be added.
    """

    def _compile_save_name_fn(var_name: str) -> str:
        # The positional "compile_save_name" must be a valid Python variable
        # name; invalid names surface as argparse errors.
        try:
            py_utils.validate_var_name(var_name)
        except ValueError as e:
            # Re-raise as ArgumentError to preserve the original error message.
            raise argparse.ArgumentError(None, "{}".format(e)) from e
        return var_name

    parser.add_argument(
        "compile_save_name",
        help="The name of a Python variable to save the compiled function to.",
        type=_compile_save_name_fn,
    )
    _add_model_flags(parser)
307
+
308
+
309
def _create_compare_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags for the compare command.

    Args:
        parser: The parser to which flags will be added.
        placeholders: Placeholders from prompts in the compiled functions.
    """

    def _resolve_llm_function_fn(
        var_name: str,
    ) -> tuple[str, llm_function.LLMFunction]:
        # Each positional argument names a previously-compiled function:
        # validate the variable name first, then check its contents.
        try:
            py_utils.validate_var_name(var_name)
        except ValueError as e:
            # Re-raise as ArgumentError to preserve the original error message.
            raise argparse.ArgumentError(None, "{}".format(e)) from e

        fn = py_utils.get_py_var(var_name)
        if isinstance(fn, llm_function.LLMFunction):
            return var_name, fn
        raise argparse.ArgumentError(
            None,
            '{} is not a function created with the "compile" command'.format(var_name),
        )

    name_help = (
        "The name of a Python variable containing a function previously created"
        ' with the "compile" command.'
    )
    for positional in ("lhs_name_and_fn", "rhs_name_and_fn"):
        parser.add_argument(positional, help=name_help, type=_resolve_llm_function_fn)

    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)
    _add_compare_flags(parser)
348
+
349
+
350
def _create_eval_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags for the eval command.

    Args:
        parser: The parser to which flags will be added.
        placeholders: Placeholders from prompts in the cell contents.
    """
    # eval takes everything "run" takes, plus comparison and ground-truth
    # flags. Order matters for the generated --help output.
    _add_model_flags(parser)
    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)
    _add_compare_flags(parser)
    _add_eval_flags(parser)
365
+
366
+
367
def _create_parser(
    placeholders: AbstractSet[str] | None,
) -> argparse.ArgumentParser:
    """Create the full parser with one subparser per command."""
    parser = argument_parser.ArgumentParser(
        prog="llm",
        description="A system for interacting with LLMs.",
        epilog="",
    )
    subparsers = parser.add_subparsers(dest="cmd")

    _create_run_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.RUN_CMD.value),
        placeholders,
    )
    _create_compile_parser(subparsers.add_parser(parsed_args_lib.CommandName.COMPILE_CMD.value))
    _create_compare_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.COMPARE_CMD.value),
        placeholders,
    )
    _create_eval_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.EVAL_CMD.value),
        placeholders,
    )
    return parser
396
+
397
+
398
+ def _validate_parsed_args(parsed_args: parsed_args_lib.ParsedArgs) -> None:
399
+ # If candidate_count is not set (i.e. is None), assuming the default value
400
+ # is 1.
401
+ if parsed_args.unique and (
402
+ parsed_args.model_args.candidate_count is None
403
+ or parsed_args.model_args.candidate_count == 1
404
+ ):
405
+ print(
406
+ '"--unique" works across candidates only: it should be used with'
407
+ " --candidate_count set to a value greater-than one."
408
+ )
409
+
410
+
411
class CmdLineParser:
    """Implementation of Magics command line parser.

    Splits a magics line into a command (parsed via argparse into ParsedArgs)
    and a sequence of "|"-separated post-processing expressions.
    """

    # Commands
    DEFAULT_CMD = parsed_args_lib.CommandName.RUN_CMD

    # Post-processing operator.
    PIPE_OP = "|"

    @classmethod
    def _split_post_processing_tokens(
        cls,
        tokens: Sequence[str],
    ) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]:
        """Splits inputs into the command and post processing tokens.

        The command is represented as a sequence of tokens.
        See comments on the PostProcessingTokens type alias.

        E.g. Given: "run --temperature 0.5 | add_score | to_lower_case"
        The command will be: ["run", "--temperature", "0.5"].
        The post processing tokens will be: [["add_score"], ["to_lower_case"]]

        Args:
          tokens: The command line tokens.

        Returns:
          A tuple of (command line, post processing tokens).
        """
        split_tokens = []
        # `start_idx` marks the beginning of the current segment; None means
        # a segment has not started yet (immediately after a pipe).
        start_idx: int | None = None
        for token_num, token in enumerate(tokens):
            if start_idx is None:
                start_idx = token_num
            if token == CmdLineParser.PIPE_OP:
                # Close the current segment; adjacent pipes yield an empty one.
                split_tokens.append(tokens[start_idx:token_num] if start_idx is not None else [])
                start_idx = None

        # Add the remaining tokens after the last PIPE_OP.
        split_tokens.append(tokens[start_idx:] if start_idx is not None else [])

        return split_tokens[0], split_tokens[1:]

    @classmethod
    def _tokenize_line(
        cls, line: str
    ) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]:
        """Parses `line` and returns command line and post processing tokens."""
        # Check to make sure there is a command at the start. If not, add the
        # default command to the list of tokens.
        tokens = shlex.split(line)
        if not tokens:
            tokens = [CmdLineParser.DEFAULT_CMD.value]
        first_token = tokens[0]
        # Add default command if the first token is not the help token.
        if not first_token[0].isalpha() and first_token not in ["-h", "--help"]:
            tokens = [CmdLineParser.DEFAULT_CMD.value] + tokens
        # Split line into tokens and post-processing
        return CmdLineParser._split_post_processing_tokens(tokens)

    @classmethod
    def _get_model_args(
        cls, parsed_results: MutableMapping[str, Any]
    ) -> tuple[MutableMapping[str, Any], model_lib.ModelArguments]:
        """Extracts fields for model args from `parsed_results`.

        Keys specific to model arguments will be removed from `parsed_results`.

        Args:
          parsed_results: A dictionary of parsed arguments (from ArgumentParser). It
            will be modified in place.

        Returns:
          A tuple of (updated parsed_results, model arguments).
        """
        model = parsed_results.pop("model", None)
        temperature = parsed_results.pop("temperature", None)
        candidate_count = parsed_results.pop("candidate_count", None)

        model_args = model_lib.ModelArguments(
            model=model,
            temperature=temperature,
            candidate_count=candidate_count,
        )
        return parsed_results, model_args

    def parse_line(
        self,
        line: str,
        placeholders: AbstractSet[str] | None = None,
    ) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]:
        """Parses the commandline and returns ParsedArgs and post-processing tokens.

        Args:
          line: The line to parse (usually contents from cell Magics).
          placeholders: Placeholders from prompts in the cell contents.

        Returns:
          A tuple of (parsed_args, post_processing_tokens).
        """
        tokens, post_processing_tokens = CmdLineParser._tokenize_line(line)

        parsed_args = self._get_parsed_args_from_cmd_line_tokens(
            tokens=tokens, placeholders=placeholders
        )

        # Special-case for "compare" command: because the prompts are compiled into
        # the left- and right-hand side functions rather than in the cell body, we
        # cannot examine the cell body to get the placeholders.
        #
        # Instead we parse the command line twice: once to get the left- and right-
        # functions, then we query the functions for their placeholders, then
        # parse the commandline again to validate the inputs.
        if parsed_args.cmd == parsed_args_lib.CommandName.COMPARE_CMD:
            assert parsed_args.lhs_name_and_fn is not None
            assert parsed_args.rhs_name_and_fn is not None
            _, lhs_fn = parsed_args.lhs_name_and_fn
            _, rhs_fn = parsed_args.rhs_name_and_fn
            parsed_args = self._get_parsed_args_from_cmd_line_tokens(
                tokens=tokens,
                placeholders=frozenset(lhs_fn.get_placeholders()).union(rhs_fn.get_placeholders()),
            )

        _validate_parsed_args(parsed_args)

        # Each "|"-separated expression is validated up front so errors
        # surface before any model call is made.
        for expr in post_processing_tokens:
            post_process_utils.validate_one_post_processing_expression(expr)

        return parsed_args, post_processing_tokens

    def _get_parsed_args_from_cmd_line_tokens(
        self,
        tokens: Sequence[str],
        placeholders: AbstractSet[str] | None,
    ) -> parsed_args_lib.ParsedArgs:
        """Returns ParsedArgs from a tokenized command line."""
        # Create a new parser to avoid reusing the temporary argparse.Namespace
        # object.
        results = _create_parser(placeholders).parse_args(tokens)

        results_dict = vars(results)
        results_dict["cmd"] = parsed_args_lib.CommandName(results_dict["cmd"])

        # Model-specific keys are popped out of the dict and repackaged as a
        # single ModelArguments value.
        results_dict, model_args = CmdLineParser._get_model_args(results_dict)
        results_dict["model_args"] = model_args

        return parsed_args_lib.ParsedArgs(**results_dict)
.venv/lib/python3.11/site-packages/google/generativeai/notebook/command.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Command."""
16
+ from __future__ import annotations
17
+
18
+ import abc
19
+ import collections
20
+ from typing import Sequence
21
+
22
+ from google.generativeai.notebook import parsed_args_lib
23
+ from google.generativeai.notebook import post_process_utils
24
+
25
+
26
# (name, fn) pair describing one post-processing step; `fn` is presumably the
# callable to apply to results — TODO confirm against post_process_utils.
ProcessingCommand = collections.namedtuple("ProcessingCommand", ["name", "fn"])


class Command(abc.ABC):
    """Base class for implementation of Magics commands like "run"."""

    @abc.abstractmethod
    def execute(
        self,
        parsed_args: parsed_args_lib.ParsedArgs,
        cell_content: str,
        post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
    ):
        """Executes the command given `parsed_args` and the `cell_content`."""

    @abc.abstractmethod
    def parse_post_processing_tokens(
        self, tokens: Sequence[Sequence[str]]
    ) -> Sequence[post_process_utils.ParsedPostProcessExpr]:
        """Parses post-processing tokens for this command.

        Each inner sequence is one "|"-separated expression from the magics
        line; subclasses decide whether post-processing is supported.
        """
.venv/lib/python3.11/site-packages/google/generativeai/notebook/compare_cmd.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """The compare command."""
16
+ from __future__ import annotations
17
+
18
+ from typing import Sequence
19
+
20
+ from google.generativeai.notebook import command
21
+ from google.generativeai.notebook import command_utils
22
+ from google.generativeai.notebook import input_utils
23
+ from google.generativeai.notebook import ipython_env
24
+ from google.generativeai.notebook import output_utils
25
+ from google.generativeai.notebook import parsed_args_lib
26
+ from google.generativeai.notebook import post_process_utils
27
+ import pandas
28
+
29
+
30
class CompareCommand(command.Command):
    """Implementation of "compare" command."""

    def __init__(
        self,
        env: ipython_env.IPythonEnv | None = None,
    ):
        """Constructor.

        Args:
            env: The IPythonEnv environment.
        """
        super().__init__()
        self._ipython_env = env

    def execute(
        self,
        parsed_args: parsed_args_lib.ParsedArgs,
        cell_content: str,
        post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
    ) -> pandas.DataFrame:
        """Runs the two compiled functions and returns their comparison."""
        # CmdLineParser has already read the inputs once to validate the
        # prompt placeholders, so status messages are suppressed on this
        # second read.
        joined_inputs = input_utils.join_inputs_sources(parsed_args, suppress_status_msgs=True)

        compare_fn = command_utils.create_llm_compare_function(
            env=self._ipython_env,
            parsed_args=parsed_args,
            post_processing_fns=post_processing_fns,
        )

        outputs = compare_fn(inputs=joined_inputs)
        output_utils.write_to_outputs(results=outputs, parsed_args=parsed_args)
        return outputs.as_pandas_dataframe()

    def parse_post_processing_tokens(
        self, tokens: Sequence[Sequence[str]]
    ) -> Sequence[post_process_utils.ParsedPostProcessExpr]:
        """Rejects post-processing tokens: "compare" does not support them."""
        if tokens:
            raise RuntimeError('Post-processing is not supported by "compare"')
        return []
.venv/lib/python3.11/site-packages/google/generativeai/notebook/compile_cmd.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """The compile command."""
16
+ from __future__ import annotations
17
+
18
+ from typing import Sequence
19
+ from google.generativeai.notebook import command
20
+ from google.generativeai.notebook import command_utils
21
+ from google.generativeai.notebook import ipython_env
22
+ from google.generativeai.notebook import model_registry
23
+ from google.generativeai.notebook import parsed_args_lib
24
+ from google.generativeai.notebook import post_process_utils
25
+ from google.generativeai.notebook import py_utils
26
+
27
+
28
class CompileCommand(command.Command):
    """Implementation of the "compile" command."""

    def __init__(
        self,
        models: model_registry.ModelRegistry,
        env: ipython_env.IPythonEnv | None = None,
    ):
        """Initializes the command.

        Args:
          models: ModelRegistry instance.
          env: The IPythonEnv environment.
        """
        super().__init__()
        self._models = models
        self._ipython_env = env

    def execute(
        self,
        parsed_args: parsed_args_lib.ParsedArgs,
        cell_content: str,
        post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
    ) -> str:
        """Compiles the cell into an LLM function and stores it in a Python variable."""
        compiled_fn = command_utils.create_llm_function(
            models=self._models,
            env=self._ipython_env,
            parsed_args=parsed_args,
            cell_content=cell_content,
            post_processing_fns=post_processing_fns,
        )

        save_name = parsed_args.compile_save_name
        py_utils.set_py_var(save_name, compiled_fn)
        return "Saved function to Python variable: {}".format(save_name)

    def parse_post_processing_tokens(
        self, tokens: Sequence[Sequence[str]]
    ) -> Sequence[post_process_utils.ParsedPostProcessExpr]:
        """Resolves post-processing tokens into parsed expressions."""
        return post_process_utils.resolve_post_processing_tokens(tokens)
.venv/lib/python3.11/site-packages/google/generativeai/notebook/eval_cmd.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """The eval command."""
16
+ from __future__ import annotations
17
+
18
+ from typing import Sequence
19
+
20
+ from google.generativeai.notebook import command
21
+ from google.generativeai.notebook import command_utils
22
+ from google.generativeai.notebook import input_utils
23
+ from google.generativeai.notebook import ipython_env
24
+ from google.generativeai.notebook import model_registry
25
+ from google.generativeai.notebook import output_utils
26
+ from google.generativeai.notebook import parsed_args_lib
27
+ from google.generativeai.notebook import post_process_utils
28
+ import pandas
29
+
30
+
31
class EvalCommand(command.Command):
    """Implementation of "eval" command."""

    def __init__(
        self,
        models: model_registry.ModelRegistry,
        env: ipython_env.IPythonEnv | None = None,
    ):
        """Initializes the command.

        Args:
          models: ModelRegistry instance.
          env: The IPythonEnv environment.
        """
        super().__init__()
        self._models = models
        self._ipython_env = env

    def execute(
        self,
        parsed_args: parsed_args_lib.ParsedArgs,
        cell_content: str,
        post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
    ) -> pandas.DataFrame:
        """Evaluates the prompt over the inputs and returns a DataFrame of results."""
        # CmdLineParser has already read the inputs once (to validate that the
        # prompt placeholders are present in the inputs), so status messages
        # are suppressed here to avoid printing them twice.
        joined_inputs = input_utils.join_inputs_sources(parsed_args, suppress_status_msgs=True)

        eval_fn = command_utils.create_llm_eval_function(
            models=self._models,
            env=self._ipython_env,
            parsed_args=parsed_args,
            cell_content=cell_content,
            post_processing_fns=post_processing_fns,
        )
        outputs = eval_fn(inputs=joined_inputs)

        output_utils.write_to_outputs(results=outputs, parsed_args=parsed_args)
        return outputs.as_pandas_dataframe()

    def parse_post_processing_tokens(
        self, tokens: Sequence[Sequence[str]]
    ) -> Sequence[post_process_utils.ParsedPostProcessExpr]:
        """Resolves post-processing tokens into parsed expressions."""
        return post_process_utils.resolve_post_processing_tokens(tokens)
.venv/lib/python3.11/site-packages/google/generativeai/notebook/flag_def.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Classes that define arguments for populating ArgumentParser.
16
+
17
+ The argparse module's ArgumentParser.add_argument() takes several parameters and
18
+ is quite customizable. However this can lead to bugs where arguments do not
19
+ behave as expected.
20
+
21
+ For better ease-of-use and better testability, define a set of classes for the
22
+ types of flags used by LLM Magics.
23
+
24
+ Sample usage:
25
+
26
+ str_flag = SingleValueFlagDef(name="title", required=True)
27
+ enum_flag = EnumFlagDef(name="colors", required=True, enum_type=ColorsEnum)
28
+
29
+ str_flag.add_argument_to_parser(my_parser)
30
+ enum_flag.add_argument_to_parser(my_parser)
31
+ """
32
+ from __future__ import annotations
33
+
34
+ import abc
35
+ import argparse
36
+ import dataclasses
37
+ import enum
38
+ from typing import Any, Callable, Sequence, Tuple, Union
39
+
40
+ from google.generativeai.notebook.lib import llmfn_inputs_source
41
+ from google.generativeai.notebook.lib import llmfn_outputs
42
+
43
# These are the intermediate types that argparse.ArgumentParser.parse_args()
# will pass command line arguments into.
_PARSETYPES = Union[str, int, float]
# These are the final result types that the intermediate parsed values will be
# converted into. It is a superset of _PARSETYPES because we support converting
# the parsed type into a more precise type, e.g. from str to Enum.
# NOTE(review): the per-member comments below look shuffled — presumably
# LLMFnInputsSource backs --inputs/--ground_truth and LLMFnOutputsSink backs
# --outputs; confirm against the flag definitions that use them.
_DESTTYPES = Union[
    _PARSETYPES,
    enum.Enum,
    Tuple[str, Callable[[str, str], Any]],
    Sequence[str],  # For --compare_fn
    llmfn_inputs_source.LLMFnInputsSource,  # For --ground_truth
    llmfn_outputs.LLMFnOutputsSink,  # For --inputs  # For --outputs
]

# The signature of a function that converts a command line argument from the
# intermediate parsed type to the result type.
_PARSEFN = Callable[[_PARSETYPES], _DESTTYPES]
61
+
62
+
63
+ def _get_type_name(x: type[Any]) -> str:
64
+ try:
65
+ return x.__name__
66
+ except AttributeError:
67
+ return str(x)
68
+
69
+
70
+ def _validate_flag_name(name: str) -> str:
71
+ """Validation for long and short names for flags."""
72
+ if not name:
73
+ raise ValueError("Cannot be empty")
74
+ if name[0] == "-":
75
+ raise ValueError("Cannot start with dash")
76
+ return name
77
+
78
+
79
@dataclasses.dataclass(frozen=True)
class FlagDef(abc.ABC):
    """Abstract base class for flag definitions.

    Attributes:
      name: Long name, e.g. "colors" will define the flag "--colors".
      required: Whether the flag must be provided on the command line.
      short_name: Optional short name.
      parse_type: The type that ArgumentParser should parse the command line
        argument to.
      dest_type: The type that the parsed value is converted to. This is used when
        we want ArgumentParser to parse as one type, then convert to a different
        type. E.g. for enums we parse as "str" then convert to the desired enum
        type in order to provide cleaner help messages.
      parse_to_dest_type_fn: If provided, this function will be used to convert
        the value from `parse_type` to `dest_type`. This can be used for
        validation as well.
      choices: If provided, limit the set of acceptable values to these choices.
      help_msg: If provided, adds help message when -h is used in the command
        line.
    """

    name: str
    required: bool = False

    short_name: str | None = None

    parse_type: type[_PARSETYPES] = str
    dest_type: type[_DESTTYPES] | None = None
    parse_to_dest_type_fn: _PARSEFN | None = None

    choices: list[_PARSETYPES] | None = None
    help_msg: str | None = None

    @abc.abstractmethod
    def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
        """Adds this flag as an argument to `parser`.

        Child classes should implement this as a call to parser.add_argument()
        with the appropriate parameters.

        Args:
          parser: The parser to which this argument will be added.
        """

    @abc.abstractmethod
    def _do_additional_validation(self) -> None:
        """For child classes to do additional validation."""

    def _get_dest_type(self) -> type[_DESTTYPES]:
        """Returns the final converted type."""
        # With no explicit destination type, the parsed type is used as-is.
        return self.parse_type if self.dest_type is None else self.dest_type

    def _get_parse_to_dest_type_fn(
        self,
    ) -> _PARSEFN:
        """Returns a function to convert from parse_type to dest_type."""
        if self.parse_to_dest_type_fn is not None:
            return self.parse_to_dest_type_fn

        dest_type = self._get_dest_type()
        if dest_type == self.parse_type:
            # No conversion needed: identity function.
            return lambda x: x
        else:
            # Default conversion: call the destination type's constructor
            # (e.g. an Enum class invoked with the parsed string).
            return dest_type

    def __post_init__(self):
        # Dataclass hook: validate the flag names, then let subclasses run
        # their own checks via _do_additional_validation().
        _validate_flag_name(self.name)
        if self.short_name is not None:
            _validate_flag_name(self.short_name)

        self._do_additional_validation()
151
+
152
+
153
+ def _has_non_default_value(
154
+ namespace: argparse.Namespace,
155
+ dest: str,
156
+ has_default: bool = False,
157
+ default_value: Any = None,
158
+ ) -> bool:
159
+ """Returns true if `namespace.dest` is set to a non-default value.
160
+
161
+ Args:
162
+ namespace: The Namespace that is populated by ArgumentParser.
163
+ dest: The attribute in the Namespace to be populated.
164
+ has_default: "None" is a valid default value so we use an additional
165
+ `has_default` boolean to indicate that `default_value` is present.
166
+ default_value: The default value to use when `has_default` is True.
167
+
168
+ Returns:
169
+ Whether namespace.dest is set to something other than the default value.
170
+ """
171
+ if not hasattr(namespace, dest):
172
+ return False
173
+
174
+ if not has_default:
175
+ # No default value provided so `namespace.dest` cannot possibly be equal to
176
+ # the default value.
177
+ return True
178
+
179
+ return getattr(namespace, dest) != default_value
180
+
181
+
182
class _SingleValueStoreAction(argparse.Action):
    """Custom Action for storing a value in an argparse.Namespace.

    This action checks that the flag is specified at-most once.
    """

    def __init__(
        self,
        option_strings,
        dest,
        dest_type: type[Any],
        parse_to_dest_type_fn: _PARSEFN,
        **kwargs,
    ):
        super().__init__(option_strings, dest, **kwargs)
        # Final type the converted value must have.
        self._dest_type = dest_type
        # Converts the parsed value (str/int/float) to the destination type.
        self._parse_to_dest_type_fn = parse_to_dest_type_fn

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: str | Sequence[Any] | None,
        option_string: str | None = None,
    ):
        # `nargs` is 1, so argparse hands us a one-element sequence rather
        # than a bare string.
        assert not isinstance(values, (str, bytes))

        already_set = _has_non_default_value(
            namespace,
            self.dest,
            has_default=hasattr(self, "default"),
            default_value=getattr(self, "default"),
        )
        if already_set:
            raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string))

        raw_value = values[0]
        try:
            converted_value = self._parse_to_dest_type_fn(raw_value)
        except Exception as e:
            raise argparse.ArgumentError(
                self,
                'Error with value "{}", got {}: {}'.format(raw_value, _get_type_name(type(e)), e),
            )

        if not isinstance(converted_value, self._dest_type):
            # A mis-behaving parse_to_dest_type_fn is a programming error, not
            # a user error, hence RuntimeError rather than ArgumentError.
            raise RuntimeError(
                "Converted to wrong type, expected {} got {}".format(
                    _get_type_name(self._dest_type),
                    _get_type_name(type(converted_value)),
                )
            )
        setattr(namespace, self.dest, converted_value)
235
+
236
+
237
+ class _MultiValuesAppendAction(argparse.Action):
238
+ """Custom Action for appending values in an argparse.Namespace.
239
+
240
+ This action checks that the flag is specified at-most once.
241
+ """
242
+
243
+ def __init__(
244
+ self,
245
+ option_strings,
246
+ dest,
247
+ dest_type: type[Any],
248
+ parse_to_dest_type_fn: _PARSEFN,
249
+ **kwargs,
250
+ ):
251
+ super().__init__(option_strings, dest, **kwargs)
252
+ self._dest_type = dest_type
253
+ self._parse_to_dest_type_fn = parse_to_dest_type_fn
254
+
255
+ def __call__(
256
+ self,
257
+ parser: argparse.ArgumentParser,
258
+ namespace: argparse.Namespace,
259
+ values: str | Sequence[Any] | None,
260
+ option_string: str | None = None,
261
+ ):
262
+ # Because `nargs` is set to "+", `values` must be a Sequence, rather
263
+ # than a string.
264
+ assert not isinstance(values, str) and not isinstance(values, bytes)
265
+
266
+ curr_value = getattr(namespace, self.dest)
267
+ if curr_value:
268
+ raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string))
269
+
270
+ for value in values:
271
+ try:
272
+ converted_value = self._parse_to_dest_type_fn(value)
273
+ except Exception as e:
274
+ raise argparse.ArgumentError(
275
+ self,
276
+ 'Error with value "{}", got {}: {}'.format(
277
+ values[0], _get_type_name(type(e)), e
278
+ ),
279
+ )
280
+
281
+ if not isinstance(converted_value, self._dest_type):
282
+ raise RuntimeError(
283
+ "Converted to wrong type, expected {} got {}".format(
284
+ self._dest_type, type(converted_value)
285
+ )
286
+ )
287
+ if converted_value in curr_value:
288
+ raise argparse.ArgumentError(self, 'Duplicate values "{}"'.format(value))
289
+
290
+ curr_value.append(converted_value)
291
+
292
+
293
class _BooleanValueStoreAction(argparse.Action):
    """Custom Action for setting a boolean value in argparse.Namespace.

    The boolean flag expects the default to be False and will set the value to
    True.
    This action checks that the flag is specified at-most once.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: str | Sequence[Any] | None,
        option_string: str | None = None,
    ):
        # If the stored value already differs from the False default, the flag
        # must have been given earlier on the command line.
        if _has_non_default_value(
            namespace,
            self.dest,
            has_default=True,
            default_value=False,
        ):
            raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string))

        setattr(namespace, self.dest, True)
325
+
326
+
327
@dataclasses.dataclass(frozen=True)
class SingleValueFlagDef(FlagDef):
    """Definition for a flag that takes a single value.

    Sample usage:
      # This defines a flag that can be specified on the command line as:
      #   --count=10
      flag = SingleValueFlagDef(name="count", parse_type=int, required=True)
      flag.add_argument_to_parser(argument_parser)

    Attributes:
      default_value: Default value for optional flags.
    """

    class _DefaultValue(enum.Enum):
        """Special value to represent "no value provided".

        "None" can be used as a default value, so in order to differentiate between
        "None" and "no value provided", create a special value for "no value
        provided".
        """

        NOT_SET = None

    default_value: _DESTTYPES | _DefaultValue | None = _DefaultValue.NOT_SET

    def _has_default_value(self) -> bool:
        """Returns whether `default_value` has been provided."""
        return self.default_value != SingleValueFlagDef._DefaultValue.NOT_SET

    def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
        """Adds this flag to `parser` as a single-value argument."""
        args = ["--" + self.name]
        if self.short_name is not None:
            args += ["-" + self.short_name]

        # Only forward optional ArgumentParser parameters that were set.
        kwargs = {}
        if self._has_default_value():
            kwargs["default"] = self.default_value
        if self.choices is not None:
            kwargs["choices"] = self.choices
        if self.help_msg is not None:
            kwargs["help"] = self.help_msg

        parser.add_argument(
            *args,
            action=_SingleValueStoreAction,
            type=self.parse_type,
            dest_type=self._get_dest_type(),
            parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(),
            required=self.required,
            # nargs=1 guarantees the action receives a one-element list.
            nargs=1,
            **kwargs,
        )

    def _do_additional_validation(self) -> None:
        # Required flags must not carry a default; optional flags must.
        if self.required:
            if self._has_default_value():
                raise ValueError("Required flags cannot have default value")
        else:
            if not self._has_default_value():
                raise ValueError("Optional flags must have a default value")

        # A non-None default must already be of the destination type.
        if self._has_default_value() and self.default_value is not None:
            if not isinstance(self.default_value, self._get_dest_type()):
                raise ValueError("Default value must be of the same type as the destination type")
392
+
393
+
394
class EnumFlagDef(SingleValueFlagDef):
    """Definition for a flag that takes a value from an Enum.

    Sample usage:
      # This defines a flag that can be specified on the command line as:
      #   --color=red
      flag = EnumFlagDef(name="color", enum_type=ColorsEnum,
                         required=True)
      flag.add_argument_to_parser(argument_parser)
    """

    def __init__(self, *args, enum_type: type[enum.Enum], **kwargs):
        """Constructor.

        Args:
          *args: Positional arguments forwarded to SingleValueFlagDef.
          enum_type: The Enum subclass whose values this flag accepts.
          **kwargs: Keyword arguments forwarded to SingleValueFlagDef.

        Raises:
          TypeError: If `enum_type` is not an Enum subclass.
          ValueError: If `parse_type`/`dest_type` are passed explicitly, or if
            an entry in `choices` is not a valid value of `enum_type`.
        """
        if not issubclass(enum_type, enum.Enum):
            raise TypeError('"enum_type" must be of type Enum')

        # These properties are derived from "enum_type" so don't let the
        # caller set them directly.
        if "parse_type" in kwargs:
            raise ValueError('Cannot set "parse_type" for EnumFlagDef; set "enum_type" instead')
        kwargs["parse_type"] = str

        if "dest_type" in kwargs:
            raise ValueError('Cannot set "dest_type" for EnumFlagDef; set "enum_type" instead')
        kwargs["dest_type"] = enum_type

        if "choices" in kwargs:
            # Verify that entries in `choices` are valid enum values.
            for x in kwargs["choices"]:
                try:
                    enum_type(x)
                except ValueError:
                    raise ValueError('Invalid value in "choices": "{}"'.format(x)) from None
        else:
            # Default to accepting every value of the enum.
            kwargs["choices"] = [x.value for x in enum_type]

        super().__init__(*args, **kwargs)
429
+
430
+
431
class MultiValuesFlagDef(FlagDef):
    """Definition for a flag that takes multiple values.

    Sample usage:
      # This defines a flag that can be specified on the command line as:
      #   --colors=red green blue
      flag = MultiValuesFlagDef(name="colors", parse_type=str, required=True)
      flag.add_argument_to_parser(argument_parser)
    """

    def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
        """Registers this flag on `parser` as a repeated-value argument."""
        flag_names = ["--" + self.name]
        if self.short_name is not None:
            flag_names.append("-" + self.short_name)

        # Only forward optional ArgumentParser parameters that were set.
        optional_kwargs = {}
        if self.choices is not None:
            optional_kwargs["choices"] = self.choices
        if self.help_msg is not None:
            optional_kwargs["help"] = self.help_msg

        parser.add_argument(
            *flag_names,
            action=_MultiValuesAppendAction,
            type=self.parse_type,
            dest_type=self._get_dest_type(),
            parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(),
            required=self.required,
            default=[],
            nargs="+",
            **optional_kwargs,
        )

    def _do_additional_validation(self) -> None:
        """No additional validation is needed beyond the base class checks."""
467
+
468
+
469
@dataclasses.dataclass(frozen=True)
class BooleanFlagDef(FlagDef):
    """Definition for a Boolean flag.

    A boolean flag is always optional with a default value of False. The flag
    does not take any values; specifying it on the command line sets it to
    True.
    """

    def _do_additional_validation(self) -> None:
        """Rejects customizations that make no sense for a boolean flag."""
        # These fields are meaningless for a valueless flag, so forbid them.
        for attr_name in ("dest_type", "parse_to_dest_type_fn", "choices"):
            if getattr(self, attr_name) is not None:
                raise ValueError("{} cannot be set for BooleanFlagDef".format(attr_name))

    def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
        """Registers this flag on `parser` as a zero-argument boolean switch."""
        flag_names = ["--" + self.name]
        if self.short_name is not None:
            flag_names.append("-" + self.short_name)

        extra_kwargs = {"help": self.help_msg} if self.help_msg is not None else {}
        parser.add_argument(
            *flag_names,
            action=_BooleanValueStoreAction,
            type=bool,
            required=False,
            default=False,
            # nargs=0 makes the flag consume no values on the command line.
            nargs=0,
            **extra_kwargs,
        )
.venv/lib/python3.11/site-packages/google/generativeai/notebook/gspread_client.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Module that holds a global gspread.client.Client."""
16
+ from __future__ import annotations
17
+
18
+ import abc
19
+ import datetime
20
+ from typing import Any, Callable, Mapping, Sequence
21
+ from google.auth import credentials
22
+ from google.generativeai.notebook import html_utils
23
+ from google.generativeai.notebook import ipython_env
24
+ from google.generativeai.notebook import sheets_id
25
+
26
+
27
# The code may be running in an environment where the gspread library has not
# been installed, so treat the import as optional and remember the failure for
# later error reporting (see _get_import_error()).
_gspread_import_error: Exception | None = None
try:
    # pylint: disable-next=g-import-not-at-top
    import gspread
except ImportError as e:
    _gspread_import_error = e
    gspread = None

# Base class of exceptions that gspread.open(), open_by_url() and open_by_key()
# may throw. Falls back to plain Exception when gspread is unavailable.
GSpreadException = Exception if gspread is None else gspread.exceptions.GSpreadException  # type: ignore
40
+
41
+
42
class SpreadsheetNotFoundError(RuntimeError):
    """Raised when a Google Sheets document cannot be located or opened."""
44
+
45
+
46
def _get_import_error() -> Exception:
    """Builds the error reported when the optional gspread dependency is missing."""
    message = '"gspread" module not imported, got: {}'.format(_gspread_import_error)
    return RuntimeError(message)
48
+
49
+
50
class GSpreadClient(abc.ABC):
    """Wrapper around gspread.client.Client.

    This adds a layer of indirection for us to inject mocks for testing.
    """

    @abc.abstractmethod
    def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
        """Validates that `sid` refers to a Google Sheets document.

        Raises an exception if false.

        Args:
          sid: The identifier for the document.
        """

    @abc.abstractmethod
    def get_all_records(
        self,
        sid: sheets_id.SheetsIdentifier,
        worksheet_id: int,
    ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
        """Returns all records for a Google Sheets worksheet.

        Args:
          sid: The identifier for the document.
          worksheet_id: The index of the worksheet within the document.

        Returns:
          A tuple of (records, display_fn), where `display_fn` emits a status
          message describing where the records were read from.
        """

    @abc.abstractmethod
    def write_records(
        self,
        sid: sheets_id.SheetsIdentifier,
        rows: Sequence[Sequence[Any]],
    ) -> None:
        """Writes results to a new worksheet in the Google Sheets document.

        Args:
          sid: The identifier for the document.
          rows: The rows of values to write.
        """
81
+
82
+
83
class GSpreadClientImpl(GSpreadClient):
    """Concrete implementation of GSpreadClient."""

    def __init__(self, client: Any, env: ipython_env.IPythonEnv | None):
        """Constructor.

        Args:
          client: Instance of gspread.client.Client.
          env: Optional instance of IPythonEnv. This is used to display messages
            such as the URL of the output Worksheet.
        """
        self._client = client
        self._ipython_env = env

    def _open(self, sid: sheets_id.SheetsIdentifier):
        """Opens a Sheets document from `sid`.

        Args:
          sid: The identifier for the Sheets document.

        Raises:
          SpreadsheetNotFoundError: If the Sheets document cannot be found or
            cannot be opened.

        Returns:
          A gspread spreadsheet object representing the document referred to by
          `sid` (callers invoke get_worksheet()/add_worksheet() on it).
        """
        # `sid` may carry a name, a key or a URL; try them in that order.
        try:
            if sid.name():
                return self._client.open(sid.name())
            if sid.key():
                return self._client.open_by_key(str(sid.key()))
            if sid.url():
                return self._client.open_by_url(str(sid.url()))
        except GSpreadException as exc:
            raise SpreadsheetNotFoundError("Unable to find Sheets with {}".format(sid)) from exc
        # None of the identifier fields were set.
        raise SpreadsheetNotFoundError("Invalid sheets_id.SheetsIdentifier")

    def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
        """Validates `sid` by attempting to open the document."""
        self._open(sid)

    def get_all_records(
        self,
        sid: sheets_id.SheetsIdentifier,
        worksheet_id: int,
    ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
        """Returns all records of a worksheet plus a status-display callback."""
        sheet = self._open(sid)
        worksheet = sheet.get_worksheet(worksheet_id)

        if self._ipython_env is not None:
            env = self._ipython_env

            def _display_fn():
                # Rich HTML message with a link to the source worksheet.
                env.display_html(
                    "Reading inputs from worksheet {}".format(
                        html_utils.get_anchor_tag(
                            url=sheets_id.SheetsURL(worksheet.url),
                            text="{} in {}".format(worksheet.title, sheet.title),
                        )
                    )
                )

        else:

            def _display_fn():
                # Plain-text fallback when no IPython environment is available.
                print("Reading inputs from worksheet {} in {}".format(worksheet.title, sheet.title))

        return worksheet.get_all_records(), _display_fn

    def write_records(
        self,
        sid: sheets_id.SheetsIdentifier,
        rows: Sequence[Sequence[Any]],
    ) -> None:
        """Writes `rows` to a freshly created worksheet in the document."""
        sheet = self._open(sid)

        # Create a new Worksheet.
        # `title` has to be carefully constructed: some characters like colon ":"
        # will not work with gspread in Worksheet.append_rows().
        current_datetime = datetime.datetime.now()
        title = f"Results {current_datetime:%Y_%m_%d} ({current_datetime:%s})"

        # append_rows() will resize the worksheet as needed, so `rows` and `cols`
        # can be set to 1 to create a worksheet with only a single cell.
        worksheet = sheet.add_worksheet(title=title, rows=1, cols=1)
        worksheet.append_rows(values=rows)

        if self._ipython_env is not None:
            self._ipython_env.display_html(
                "Results written to new worksheet {}".format(
                    html_utils.get_anchor_tag(
                        url=sheets_id.SheetsURL(worksheet.url),
                        text="{} in {}".format(worksheet.title, sheet.title),
                    )
                )
            )
        else:
            print("Results written to new worksheet {} in {}".format(worksheet.title, sheet.title))
182
+
183
+
184
class NullGSpreadClient(GSpreadClient):
    """Null-object implementation of GSpreadClient.

    This class raises an error if any of its methods are called. It is used when
    the gspread library is not available.
    """

    def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
        """Always raises: gspread is not installed."""
        raise _get_import_error()

    def get_all_records(
        self,
        sid: sheets_id.SheetsIdentifier,
        worksheet_id: int,
    ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
        """Always raises: gspread is not installed."""
        raise _get_import_error()

    def write_records(
        self,
        sid: sheets_id.SheetsIdentifier,
        rows: Sequence[Sequence[Any]],
    ) -> None:
        """Always raises: gspread is not installed."""
        raise _get_import_error()
207
+
208
+
209
# Global instance of gspread client; populated by authorize() (or
# testonly_set_client() in tests) and read via get_client().
_gspread_client: GSpreadClient | None = None
211
+
212
+
213
def authorize(creds: credentials.Credentials, env: ipython_env.IPythonEnv | None) -> None:
    """Sets up credential for gspreads.

    Installs the global GSpreadClient used by get_client().
    """
    global _gspread_client
    if gspread is None:
        # gspread failed to import; install a null object that raises on use.
        _gspread_client = NullGSpreadClient()
    else:
        _gspread_client = GSpreadClientImpl(client=gspread.authorize(creds), env=env)  # type: ignore
221
+
222
+
223
def get_client() -> GSpreadClient:
    """Returns the global GSpreadClient, failing if authorize() was never called."""
    # The global is either None (never authorized) or a client instance.
    if _gspread_client is None:
        raise RuntimeError("Must call authorize() first")
    return _gspread_client
227
+
228
+
229
def testonly_set_client(client: GSpreadClient) -> None:
    """Overrides the global client for testing."""
    # Bypasses authorize() so tests can inject a fake GSpreadClient.
    global _gspread_client
    _gspread_client = client
.venv/lib/python3.11/site-packages/google/generativeai/notebook/html_utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Utilities for generating HTML."""
16
+ from __future__ import annotations
17
+
18
+ from xml.etree import ElementTree
19
+ from google.generativeai.notebook import sheets_id
20
+
21
+
22
def get_anchor_tag(url: sheets_id.SheetsURL, text: str) -> str:
    """Returns an HTML anchor-tag fragment linking to `url`.

    The xml.etree library handles attribute/text escaping, so `text` may
    contain arbitrary characters.

    Args:
      url: The Sheets URL to link to.
      text: The text body of the link; "link" is used when empty.

    Returns:
      A string representing an HTML fragment.
    """
    anchor = ElementTree.Element("a")
    # Open in a new window/tab.
    anchor.set("target", "_blank")
    # See:
    # https://developer.chrome.com/en/docs/lighthouse/best-practices/external-anchors-use-rel-noopener/
    anchor.set("rel", "noopener")
    anchor.set("href", str(url))
    anchor.text = text or "link"
    return ElementTree.tostring(anchor, encoding="unicode", method="html")
.venv/lib/python3.11/site-packages/google/generativeai/notebook/input_utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Utilities for handling input variables."""
16
+ from __future__ import annotations
17
+
18
+ from typing import Callable, Mapping
19
+
20
+ from google.generativeai.notebook import parsed_args_lib
21
+ from google.generativeai.notebook import py_utils
22
+ from google.generativeai.notebook.lib import llmfn_input_utils
23
+ from google.generativeai.notebook.lib import llmfn_inputs_source
24
+
25
+
26
class _NormalizedInputsSource(llmfn_inputs_source.LLMFnInputsSource):
    """Adapts an already-normalized inputs list to the LLMFnInputsSource API.

    LLMFunction deliberately does not accept NormalizedInputsList directly:
    it is an internal representation whose exposure we want to minimize. When
    data is already normalized (e.g. from join_prompt_inputs()), wrapping it
    in this source lets it be passed to LLMFunction anyway.
    """

    def __init__(self, normalized_inputs: llmfn_inputs_source.NormalizedInputsList):
        super().__init__()
        # Stored as-is; no further conversion is ever needed.
        self._normalized_inputs = normalized_inputs

    def _to_normalized_inputs_impl(
        self,
    ) -> tuple[llmfn_inputs_source.NormalizedInputsList, Callable[[], None]]:
        # No resources to release, hence the no-op cleanup callable.
        return self._normalized_inputs, lambda: None
46
+
47
+
48
def get_inputs_source_from_py_var(
    var_name: str,
) -> llmfn_inputs_source.LLMFnInputsSource:
    """Resolves the Python variable named `var_name` into an inputs source."""
    value = py_utils.get_py_var(var_name)
    if isinstance(value, llmfn_inputs_source.LLMFnInputsSource):
        # Already a source; no conversion needed.
        return value
    return _NormalizedInputsSource(llmfn_input_utils.to_normalized_inputs(value))
57
+
58
+
59
def join_inputs_sources(
    parsed_args: parsed_args_lib.ParsedArgs,
    suppress_status_msgs: bool = False,
) -> llmfn_inputs_source.LLMFnInputsSource:
    """Gets a single combined input source from `parsed_args`.

    Inputs from `--inputs` come first, followed by Sheets inputs, preserving
    the order within each group.
    """
    merged: list[Mapping[str, str]] = []
    for src in list(parsed_args.inputs) + list(parsed_args.sheets_input_names):
        merged.extend(src.to_normalized_inputs(suppress_status_msgs=suppress_status_msgs))
    return _NormalizedInputsSource(merged)
.venv/lib/python3.11/site-packages/google/generativeai/notebook/ipython_env_impl.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """IPythonEnvImpl."""
16
+ from __future__ import annotations
17
+
18
+ from typing import Any
19
+ from google.generativeai.notebook import ipython_env
20
+ from IPython.core import display as ipython_display
21
+
22
+
23
class IPythonEnvImpl(ipython_env.IPythonEnv):
    """Concrete implementation of IPythonEnv backed by IPython.core.display."""

    def display(self, x: Any) -> None:
        # Delegate to IPython's rich display mechanism.
        ipython_display.display(x)

    def display_html(self, x: str) -> None:
        # Wrap the raw HTML string so IPython renders it as markup rather
        # than as plain text.
        ipython_display.display(ipython_display.HTML(x))
.venv/lib/python3.11/site-packages/google/generativeai/notebook/magics.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Colab Magics class.
16
+
17
+ Installs %%llm magics.
18
+ """
19
+ from __future__ import annotations
20
+
21
+ import abc
22
+
23
+ from google.auth import credentials
24
+ from google.generativeai import client as genai
25
+ from google.generativeai.notebook import gspread_client
26
+ from google.generativeai.notebook import ipython_env
27
+ from google.generativeai.notebook import ipython_env_impl
28
+ from google.generativeai.notebook import magics_engine
29
+ from google.generativeai.notebook import post_process_utils
30
+ from google.generativeai.notebook import sheets_utils
31
+
32
+ import IPython
33
+ from IPython.core import magic
34
+
35
+
36
# Set the UA to distinguish the magic from the client. Do this at import-time
# so that a user can still call `genai.configure()`, and both their settings
# and this are honored.
genai.USER_AGENT = "genai-py-magic"

# Public re-exports so notebook users can reference these via this module.
SheetsInputs = sheets_utils.SheetsInputs
SheetsOutputs = sheets_utils.SheetsOutputs

# Decorator functions for post-processing.
post_process_add_fn = post_process_utils.post_process_add_fn
post_process_replace_fn = post_process_utils.post_process_replace_fn

# Globals.
# Lazily created by _get_ipython_env(); None until first use.
_ipython_env: ipython_env.IPythonEnv | None = None
50
+
51
+
52
def _get_ipython_env() -> ipython_env.IPythonEnv:
    """Returns the process-wide IPythonEnv, creating it on first use."""
    global _ipython_env
    env = _ipython_env
    if env is None:
        env = ipython_env_impl.IPythonEnvImpl()
        _ipython_env = env
    return env
58
+
59
+
60
def authorize(creds: credentials.Credentials) -> None:
    """Sets up credentials.

    This is used for interacting with Google APIs, such as Google Sheets.

    Args:
      creds: The credentials that will be used (e.g. to read from Google
        Sheets.)
    """
    gspread_client.authorize(creds=creds, env=_get_ipython_env())
69
+
70
+
71
class AbstractMagics(abc.ABC):
    """Defines interface to Magics class."""

    @abc.abstractmethod
    def llm(self, cell_line: str | None, cell_body: str | None):
        """Perform various LLM-related operations.

        Args:
          cell_line: String to pass to the MagicsEngine.
          cell_body: Contents of the cell body.
        """
        # Subclasses must implement; see MagicsImpl for the real behavior.
        raise NotImplementedError()
83
+
84
+
85
class MagicsImpl(AbstractMagics):
    """Concrete magics implementation.

    A separate class from the registration shim so that a single, global
    instance services every magic invocation.
    """

    def __init__(self):
        self._engine = magics_engine.MagicsEngine(env=_get_ipython_env())

    def llm(self, cell_line: str | None, cell_body: str | None):
        """Perform various LLM-related operations.

        Args:
          cell_line: String to pass to the MagicsEngine.
          cell_body: Contents of the cell body.

        Returns:
          Results from running MagicsEngine.
        """
        # Normalize None to "" before handing off to the engine.
        return self._engine.execute_cell(cell_line or "", cell_body or "")
108
+
109
+
110
@magic.magics_class
class Magics(magic.Magics):
    """Class to register the magic with Colab.

    Objects of this class delegate all calls to a single, global instance.
    """

    _instance = None  # Global instance, created lazily.

    @classmethod
    def get_instance(cls) -> AbstractMagics:
        """Retrieve global instance of the Magics object."""
        if cls._instance is None:
            cls._instance = MagicsImpl()
        return cls._instance

    @magic.line_cell_magic
    def llm(self, cell_line: str | None, cell_body: str | None):
        """Perform various LLM-related operations.

        Args:
          cell_line: String to pass to the MagicsEngine.
          cell_body: Contents of the cell body.

        Returns:
          Results from running MagicsEngine.
        """
        return Magics.get_instance().llm(cell_line=cell_line, cell_body=cell_body)

    @magic.line_cell_magic
    def palm(self, cell_line: str | None, cell_body: str | None):
        # %%palm is an alias of %%llm.
        return self.llm(cell_line, cell_body)

    @magic.line_cell_magic
    def gemini(self, cell_line: str | None, cell_body: str | None):
        # %%gemini is an alias of %%llm.
        return self.llm(cell_line, cell_body)
148
+
149
+
150
+ IPython.get_ipython().register_magics(Magics)
.venv/lib/python3.11/site-packages/google/generativeai/notebook/magics_engine.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """MagicsEngine class."""
16
+ from __future__ import annotations
17
+
18
+ from typing import AbstractSet, Sequence
19
+
20
+ from google.generativeai.notebook import argument_parser
21
+ from google.generativeai.notebook import cmd_line_parser
22
+ from google.generativeai.notebook import command
23
+ from google.generativeai.notebook import compare_cmd
24
+ from google.generativeai.notebook import compile_cmd
25
+ from google.generativeai.notebook import eval_cmd
26
+ from google.generativeai.notebook import ipython_env
27
+ from google.generativeai.notebook import model_registry
28
+ from google.generativeai.notebook import parsed_args_lib
29
+ from google.generativeai.notebook import post_process_utils
30
+ from google.generativeai.notebook import run_cmd
31
+ from google.generativeai.notebook.lib import prompt_utils
32
+
33
+
34
class MagicsEngine:
    """Implementation of functionality used by Magics.

    This class provides the implementation for Magics, decoupled from the
    details of integrating with Colab Magics such as registration.
    """

    def __init__(
        self,
        registry: model_registry.ModelRegistry | None = None,
        env: ipython_env.IPythonEnv | None = None,
    ):
        # Used for rich display and error reporting; may be None outside
        # an IPython context.
        self._ipython_env = env
        models = registry or model_registry.ModelRegistry()
        # One handler per supported command keyword.
        self._cmd_handlers: dict[parsed_args_lib.CommandName, command.Command] = {
            parsed_args_lib.CommandName.RUN_CMD: run_cmd.RunCommand(models=models, env=env),
            parsed_args_lib.CommandName.COMPILE_CMD: compile_cmd.CompileCommand(
                models=models, env=env
            ),
            parsed_args_lib.CommandName.COMPARE_CMD: compare_cmd.CompareCommand(env=env),
            parsed_args_lib.CommandName.EVAL_CMD: eval_cmd.EvalCommand(models=models, env=env),
        }

    def parse_line(
        self,
        line: str,
        placeholders: AbstractSet[str],
    ) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]:
        """Parses the magic line into command args and post-processing tokens."""
        return cmd_line_parser.CmdLineParser().parse_line(line, placeholders)

    def _get_handler(self, line: str, placeholders: AbstractSet[str]) -> tuple[
        command.Command,
        parsed_args_lib.ParsedArgs,
        Sequence[post_process_utils.ParsedPostProcessExpr],
    ]:
        """Given the command line, parse and return all components.

        Args:
          line: The LLM Magics command line.
          placeholders: Placeholders from prompts in the cell contents.

        Returns:
          A three-tuple containing:
          - The command (e.g. "run")
          - Parsed arguments for the command,
          - Parsed post-processing expressions
        """
        parsed_args, post_processing_tokens = self.parse_line(line, placeholders)
        cmd_name = parsed_args.cmd
        handler = self._cmd_handlers[cmd_name]
        # Post-processing parsing is delegated to the handler because the
        # valid expressions differ per command.
        post_processing_fns = handler.parse_post_processing_tokens(post_processing_tokens)
        return handler, parsed_args, post_processing_fns

    def execute_cell(self, line: str, cell_content: str):
        """Executes the supplied magic line and cell payload."""
        cell = _clean_cell(cell_content)
        placeholders = prompt_utils.get_placeholders(cell)

        try:
            handler, parsed_args, post_processing_fns = self._get_handler(line, placeholders)
            return handler.execute(parsed_args, cell, post_processing_fns)
        except argument_parser.ParserNormalExit as e:
            # E.g. "--help": not an error, so render the message as output.
            if self._ipython_env is not None:
                e.set_ipython_env(self._ipython_env)
            # ParserNormalExit implements the _ipython_display_ method so it can
            # be returned as the output of this cell for display.
            return e
        except argument_parser.ParserError as e:
            e.display(self._ipython_env)
            # Raise an exception to indicate that execution for this cell has
            # failed.
            # The exception is re-raised as SystemExit because Colab automatically
            # suppresses traceback for SystemExit but not other exceptions. Because
            # ParserErrors are usually due to user error (e.g. a missing required
            # flag or an invalid flag value), we want to hide the traceback to
            # avoid detracting the user from the error message, and we want to
            # reserve exceptions-with-traceback for actual bugs and unexpected
            # errors.
            error_msg = "Got parser error: {}".format(e.msgs()[-1]) if e.msgs() else ""
            raise SystemExit(error_msg) from e
114
+
115
+
116
+ def _clean_cell(cell_content: str) -> str:
117
+ # Colab includes a trailing newline in cell_content. Remove only the last
118
+ # line break from cell contents (i.e. not rstrip), so that multi-line and
119
+ # intentional line breaks are preserved, but single-line prompts don't have
120
+ # a trailing line break.
121
+ cell = cell_content
122
+ if cell.endswith("\n"):
123
+ cell = cell[:-1]
124
+ return cell
.venv/lib/python3.11/site-packages/google/generativeai/notebook/model_registry.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Maintains set of LLM models that can be instantiated by name."""
16
+ from __future__ import annotations
17
+
18
+ import enum
19
+ from typing import Callable
20
+
21
+ from google.generativeai.notebook import text_model
22
+ from google.generativeai.notebook.lib import model as model_lib
23
+
24
+
25
class ModelName(enum.Enum):
    # Test/debug model that echoes its input (model_lib.EchoModel).
    ECHO_MODEL = "echo"
    # Text generation model (text_model.TextModel).
    TEXT_MODEL = "text"
28
+
29
+
30
class ModelRegistry:
    """Registry that instantiates and caches models."""

    DEFAULT_MODEL = ModelName.TEXT_MODEL

    def __init__(self):
        # Instantiated models, keyed by ModelName; filled lazily.
        self._model_cache: dict[ModelName, model_lib.AbstractModel] = {}
        # Factories used to construct each model on first request.
        self._model_constructors: dict[ModelName, Callable[[], model_lib.AbstractModel]] = {
            ModelName.ECHO_MODEL: model_lib.EchoModel,
            ModelName.TEXT_MODEL: text_model.TextModel,
        }

    def get_model(self, model_name: ModelName) -> model_lib.AbstractModel:
        """Given `model_name`, return the corresponding Model instance.

        Model instances are cached and reused for the same `model_name`.

        Args:
          model_name: The name of the model.

        Returns:
          The corresponding model instance for `model_name`.
        """
        model = self._model_cache.get(model_name)
        if model is None:
            # First request for this name: construct and memoize.
            model = self._model_constructors[model_name]()
            self._model_cache[model_name] = model
        return model
.venv/lib/python3.11/site-packages/google/generativeai/notebook/parsed_args_lib.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Results from parsing the commandline.
16
+
17
+ This module separates the results from commandline parsing from the parser
18
+ itself so that classes that operate on the results (e.g. the subclasses of
19
+ Command) do not have to depend on the commandline parser as well.
20
+ """
21
+ from __future__ import annotations
22
+
23
+ import dataclasses
24
+ import enum
25
+ from typing import Any, Callable, Sequence
26
+
27
+ from google.generativeai.notebook import model_registry
28
+ from google.generativeai.notebook.lib import llm_function
29
+ from google.generativeai.notebook.lib import llmfn_inputs_source
30
+ from google.generativeai.notebook.lib import llmfn_outputs
31
+ from google.generativeai.notebook.lib import model as model_lib
32
+
33
# Post processing tokens are represented as a sequence of sequence of tokens,
# because the pipe operator could be used more than once.
PostProcessingTokens = Sequence[Sequence[str]]

# The type of function taken by the "compare_fn" flag.
# It takes the text_results of the left- and right-hand side functions as
# inputs and returns a comparison result.
TextResultCompareFn = Callable[[str, str], Any]
41
+
42
+
43
class CommandName(enum.Enum):
    # Values are the literal command keywords accepted on the magic line.
    RUN_CMD = "run"
    COMPILE_CMD = "compile"
    COMPARE_CMD = "compare"
    EVAL_CMD = "eval"
48
+
49
+
50
@dataclasses.dataclass(frozen=True)
class ParsedArgs:
    """The results of parsing the command line."""

    # The command being executed (run/compile/compare/eval).
    cmd: CommandName

    # For run, compile and eval commands.
    model_args: model_lib.ModelArguments
    model_type: model_registry.ModelName | None = None
    unique: bool = False

    # For run, compare and eval commands.
    # Input sources from the --inputs flag.
    inputs: Sequence[llmfn_inputs_source.LLMFnInputsSource] = dataclasses.field(
        default_factory=list
    )
    # Input sources backed by Google Sheets (--sheets_input_names).
    sheets_input_names: Sequence[llmfn_inputs_source.LLMFnInputsSource] = dataclasses.field(
        default_factory=list
    )

    # Output sinks; the Sheets-backed variants are kept separately.
    outputs: Sequence[llmfn_outputs.LLMFnOutputsSink] = dataclasses.field(default_factory=list)
    sheets_output_names: Sequence[llmfn_outputs.LLMFnOutputsSink] = dataclasses.field(
        default_factory=list
    )

    # For compile command.
    compile_save_name: str | None = None

    # For compare command.
    # Each side is a (display name, function) pair.
    lhs_name_and_fn: tuple[str, llm_function.LLMFunction] | None = None
    rhs_name_and_fn: tuple[str, llm_function.LLMFunction] | None = None

    # For compare and eval commands.
    # (name, comparison function) pairs from the --compare_fn flag.
    compare_fn: Sequence[tuple[str, TextResultCompareFn]] = dataclasses.field(default_factory=list)

    # For eval command.
    ground_truth: Sequence[str] = dataclasses.field(default_factory=list)
.venv/lib/python3.11/site-packages/google/generativeai/notebook/post_process_utils_test_helper.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Helper module for post_process_utils_test."""
16
+ from __future__ import annotations
17
+
18
+ from google.generativeai.notebook import post_process_utils
19
+
20
+
21
def add_length(x: str) -> int:
    """Returns the length of `x` (undecorated helper for the tests)."""
    length = len(x)
    return length
23
+
24
+
25
@post_process_utils.post_process_add_fn
def add_length_decorated(x: str) -> int:
    """Length helper pre-wrapped with the add-fn decorator (test fixture)."""
    return len(x)
28
+
29
+
30
@post_process_utils.post_process_replace_fn
def to_upper(x: str) -> str:
    """Uppercasing helper pre-wrapped with the replace-fn decorator (test fixture)."""
    return x.upper()
.venv/lib/python3.11/site-packages/google/generativeai/notebook/run_cmd.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """The run command."""
16
+ from __future__ import annotations
17
+
18
+ from typing import Sequence
19
+
20
+ from google.generativeai.notebook import command
21
+ from google.generativeai.notebook import command_utils
22
+ from google.generativeai.notebook import input_utils
23
+ from google.generativeai.notebook import ipython_env
24
+ from google.generativeai.notebook import model_registry
25
+ from google.generativeai.notebook import output_utils
26
+ from google.generativeai.notebook import parsed_args_lib
27
+ from google.generativeai.notebook import post_process_utils
28
+ import pandas
29
+
30
+
31
class RunCommand(command.Command):
    """Implementation of the "run" command."""

    def __init__(
        self,
        models: model_registry.ModelRegistry,
        env: ipython_env.IPythonEnv | None = None,
    ):
        """Constructor.

        Args:
          models: ModelRegistry instance.
          env: The IPythonEnv environment.
        """
        super().__init__()
        self._models = models
        self._ipython_env = env

    def execute(
        self,
        parsed_args: parsed_args_lib.ParsedArgs,
        cell_content: str,
        post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
    ) -> pandas.DataFrame:
        """Runs the prompt over all inputs and returns the results as a DataFrame."""
        # CmdLineParser has already read the inputs once to validate that the
        # prompt's placeholders are present, so status messages from this
        # second read are suppressed.
        input_source = input_utils.join_inputs_sources(parsed_args, suppress_status_msgs=True)

        fn = command_utils.create_llm_function(
            models=self._models,
            env=self._ipython_env,
            parsed_args=parsed_args,
            cell_content=cell_content,
            post_processing_fns=post_processing_fns,
        )

        outputs = fn(inputs=input_source)
        # Persist to any configured sinks (e.g. Python vars, Sheets).
        output_utils.write_to_outputs(results=outputs, parsed_args=parsed_args)
        return outputs.as_pandas_dataframe()

    def parse_post_processing_tokens(
        self, tokens: Sequence[Sequence[str]]
    ) -> Sequence[post_process_utils.ParsedPostProcessExpr]:
        """Resolves raw post-processing tokens into callable expressions."""
        return post_process_utils.resolve_post_processing_tokens(tokens)
.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_id.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Module for classes related to identifying a Sheets document."""
16
+ from __future__ import annotations
17
+
18
+ import re
19
+ from google.generativeai.notebook import sheets_sanitize_url
20
+
21
+
22
+ def _sanitize_key(key: str) -> str:
23
+ if not re.fullmatch("[a-zA-Z0-9_-]+", key):
24
+ raise ValueError('"{}" is not a valid Sheets key'.format(key))
25
+ return key
26
+
27
+
28
class SheetsURL:
    """Holds a Sheets URL that is guaranteed to have been sanitized."""

    def __init__(self, url: str):
        # Raises ValueError if `url` fails the Sheets URL safety checks.
        self._url: str = sheets_sanitize_url.sanitize_sheets_url(url)

    def __str__(self) -> str:
        return self._url
36
+
37
+
38
class SheetsKey:
    """Holds a Sheets key that is guaranteed to have been sanitized."""

    def __init__(self, key: str):
        # Raises ValueError if `key` contains unsafe characters.
        self._key: str = _sanitize_key(key)

    def __str__(self) -> str:
        return self._key
46
+
47
+
48
class SheetsIdentifier:
    """A single way of identifying a Sheets document.

    The gspread library can look up a Sheets document by name, by url or by
    key; an instance of this class carries exactly one of those three.
    """

    def __init__(
        self,
        name: str | None = None,
        key: SheetsKey | None = None,
        url: SheetsURL | None = None,
    ):
        """Constructor.

        Exactly one of the arguments should be provided.

        Args:
          name: The name of the Sheets document. More-than-one documents can
            share a name, so this is the least precise identifier.
          key: The key of the Sheets document.
          url: The url of the Sheets document.

        Raises:
          ValueError: If the caller does not specify exactly one of name, url
            or key.
        """
        self._name = name
        self._key = key
        self._url = url

        # Truthiness check: empty strings count as "not provided".
        if sum(1 for value in (name, key, url) if value) != 1:
            raise ValueError("Must set exactly one of name, key or url")

    def name(self) -> str | None:
        return self._name

    def key(self) -> SheetsKey | None:
        return self._key

    def url(self) -> SheetsURL | None:
        return self._url

    def __str__(self):
        if self._name:
            return "name={}".format(self._name)
        if self._key:
            return "key={}".format(self._key)
        return "url={}".format(self._url)
.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_sanitize_url.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Utilities for working with URLs."""
16
+ from __future__ import annotations
17
+
18
+ import re
19
+ from urllib import parse
20
+
21
+
22
+ def _validate_url_part(part: str) -> None:
23
+ if not re.fullmatch("[a-zA-Z0-9_-]*", part):
24
+ raise ValueError('"{}" is outside the restricted character set'.format(part))
25
+
26
+
27
+ def _validate_url_query_or_fragment(part: str) -> None:
28
+ for key, values in parse.parse_qs(part).items():
29
+ _validate_url_part(key)
30
+ for value in values:
31
+ _validate_url_part(value)
32
+
33
+
34
+ def sanitize_sheets_url(url: str) -> str:
35
+ """Sanitize a Sheets URL.
36
+
37
+ Run some saftey checks to check whether `url` is a Sheets URL. This is not a
38
+ general-purpose URL sanitizer. Rather, it makes use of the fact that we know
39
+ the URL has to be for Sheets so we can make a few assumptions about (e.g. the
40
+ domain).
41
+
42
+ Args:
43
+ url: The url to sanitize.
44
+
45
+ Returns:
46
+ The sanitized url.
47
+
48
+ Raises:
49
+ ValueError: If `url` does not match the expected restrictions for a Sheets
50
+ URL.
51
+ """
52
+ parse_result = parse.urlparse(url)
53
+ if parse_result.scheme != "https":
54
+ raise ValueError(
55
+ 'Scheme for Sheets url must be "https", got "{}"'.format(parse_result.scheme)
56
+ )
57
+ if parse_result.netloc not in ("docs.google.com", "sheets.googleapis.com"):
58
+ raise ValueError(
59
+ 'Domain for Sheets url must be "docs.google.com", got "{}"'.format(parse_result.netloc)
60
+ )
61
+
62
+ # Path component.
63
+ try:
64
+ for fragment in parse_result.path.split("/"):
65
+ _validate_url_part(fragment)
66
+ except ValueError as exc:
67
+ raise ValueError('Invalid path for Sheets url, got "{}"'.format(parse_result.path)) from exc
68
+
69
+ # Params component.
70
+ if parse_result.params:
71
+ raise ValueError('Params component must be empty, got "{}"'.format(parse_result.params))
72
+
73
+ # Query component.
74
+ try:
75
+ _validate_url_query_or_fragment(parse_result.query)
76
+ except ValueError as exc:
77
+ raise ValueError(
78
+ 'Invalid query for Sheets url, got "{}"'.format(parse_result.query)
79
+ ) from exc
80
+
81
+ # Fragment component.
82
+ try:
83
+ _validate_url_query_or_fragment(parse_result.fragment)
84
+ except ValueError as exc:
85
+ raise ValueError(
86
+ 'Invalid fragment for Sheets url, got "{}"'.format(parse_result.fragment)
87
+ ) from exc
88
+
89
+ return url
.venv/lib/python3.11/site-packages/google/generativeai/notebook/sheets_utils.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """SheetsInputs."""
16
+ from __future__ import annotations
17
+
18
+ from typing import Any, Callable, Mapping, Sequence
19
+ from urllib import parse
20
+ from google.generativeai.notebook import gspread_client
21
+ from google.generativeai.notebook import sheets_id
22
+ from google.generativeai.notebook.lib import llmfn_inputs_source
23
+ from google.generativeai.notebook.lib import llmfn_outputs
24
+
25
+
26
+ def _try_sheet_id_as_url(value: str) -> sheets_id.SheetsIdentifier | None:
27
+ """Try to open a Sheets document with `value` as a URL."""
28
+ try:
29
+ parse_result = parse.urlparse(value)
30
+ except ValueError:
31
+ # If there's a URL parsing error, then it's not a URL.
32
+ return None
33
+
34
+ if parse_result.scheme:
35
+ # If it looks like a URL, try to open the document as a URL but don't fall
36
+ # back to trying as key or name since it's very unlikely that a key or name
37
+ # looks like a URL.
38
+ sid = sheets_id.SheetsIdentifier(url=sheets_id.SheetsURL(value))
39
+ gspread_client.get_client().validate(sid)
40
+ return sid
41
+
42
+ return None
43
+
44
+
45
+ def _try_sheet_id_as_key(value: str) -> sheets_id.SheetsIdentifier | None:
46
+ """Try to open a Sheets document with `value` as a key."""
47
+ try:
48
+ sid = sheets_id.SheetsIdentifier(key=sheets_id.SheetsKey(value))
49
+ except ValueError:
50
+ # `value` is not a well-formed Sheets key.
51
+ return None
52
+
53
+ try:
54
+ gspread_client.get_client().validate(sid)
55
+ except gspread_client.SpreadsheetNotFoundError:
56
+ return None
57
+ return sid
58
+
59
+
60
+ def _try_sheet_id_as_name(value: str) -> sheets_id.SheetsIdentifier | None:
61
+ """Try to open a Sheets document with `value` as a name."""
62
+ sid = sheets_id.SheetsIdentifier(name=value)
63
+ try:
64
+ gspread_client.get_client().validate(sid)
65
+ except gspread_client.SpreadsheetNotFoundError:
66
+ return None
67
+ return sid
68
+
69
+
70
+ def get_sheets_id_from_str(value: str) -> sheets_id.SheetsIdentifier:
71
+ if sid := _try_sheet_id_as_url(value):
72
+ return sid
73
+ if sid := _try_sheet_id_as_key(value):
74
+ return sid
75
+ if sid := _try_sheet_id_as_name(value):
76
+ return sid
77
+ raise RuntimeError('No Sheets found with "{}" as URL, key or name'.format(value))
78
+
79
+
80
+ class SheetsInputs(llmfn_inputs_source.LLMFnInputsSource):
81
+ """Inputs to an LLMFunction from Google Sheets."""
82
+
83
+ def __init__(self, sid: sheets_id.SheetsIdentifier, worksheet_id: int = 0):
84
+ super().__init__()
85
+ self._sid = sid
86
+ self._worksheet_id = worksheet_id
87
+
88
+ def _to_normalized_inputs_impl(
89
+ self,
90
+ ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
91
+ return gspread_client.get_client().get_all_records(
92
+ sid=self._sid, worksheet_id=self._worksheet_id
93
+ )
94
+
95
+
96
+ class SheetsOutputs(llmfn_outputs.LLMFnOutputsSink):
97
+ """Writes outputs from an LLMFunction to Google Sheets."""
98
+
99
+ def __init__(self, sid: sheets_id.SheetsIdentifier):
100
+ self._sid = sid
101
+
102
+ def write_outputs(self, outputs: llmfn_outputs.LLMFnOutputsBase) -> None:
103
+ # Transpose `outputs` into a list of rows.
104
+ outputs_dict = outputs.as_dict()
105
+ outputs_rows: list[Sequence[Any]] = [list(outputs_dict.keys())]
106
+ outputs_rows.extend([list(x) for x in zip(*outputs_dict.values())])
107
+
108
+ gspread_client.get_client().write_records(
109
+ sid=self._sid,
110
+ rows=outputs_rows,
111
+ )
.venv/lib/python3.11/site-packages/google/generativeai/notebook/text_model.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Model that uses the Text service."""
16
+ from __future__ import annotations
17
+
18
+ from google.api_core import retry
19
+ import google.generativeai as genai
20
+ from google.generativeai.types import generation_types
21
+ from google.generativeai.notebook.lib import model as model_lib
22
+
23
+ _DEFAULT_MODEL = "models/gemini-1.5-flash"
24
+
25
+
26
+ class TextModel(model_lib.AbstractModel):
27
+ """Concrete model that uses the generate_content service."""
28
+
29
+ def _generate_text(
30
+ self,
31
+ prompt: str,
32
+ model: str | None = None,
33
+ temperature: float | None = None,
34
+ candidate_count: int | None = None,
35
+ ) -> generation_types.GenerateContentResponse:
36
+ gen_config = {}
37
+ if temperature is not None:
38
+ gen_config["temperature"] = temperature
39
+ if candidate_count is not None:
40
+ gen_config["candidate_count"] = candidate_count
41
+
42
+ model_name = model or _DEFAULT_MODEL
43
+ gen_model = genai.GenerativeModel(model_name=model_name)
44
+ gc = genai.types.generation_types.GenerationConfig(**gen_config)
45
+ return gen_model.generate_content(prompt, generation_config=gc)
46
+
47
+ def call_model(
48
+ self,
49
+ model_input: str,
50
+ model_args: model_lib.ModelArguments | None = None,
51
+ ) -> model_lib.ModelResults:
52
+ if model_args is None:
53
+ model_args = model_lib.ModelArguments()
54
+
55
+ # Wrap the generation function here, rather than decorate, so that it
56
+ # applies to any overridden calls too.
57
+ retryable_fn = retry.Retry(retry.if_transient_error)(self._generate_text)
58
+ response = retryable_fn(
59
+ prompt=model_input,
60
+ model=model_args.model,
61
+ temperature=model_args.temperature,
62
+ candidate_count=model_args.candidate_count,
63
+ )
64
+
65
+ text_outputs = []
66
+ for c in response.candidates:
67
+ text_outputs.append("".join(p.text for p in c.content.parts))
68
+
69
+ return model_lib.ModelResults(
70
+ model_input=model_input,
71
+ text_results=text_outputs,
72
+ )