koichi12 commited on
Commit
58a2d66
·
verified ·
1 Parent(s): 2cbac61

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/google/generativeai/__pycache__/files.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/google/generativeai/__pycache__/generative_models.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/google/generativeai/__pycache__/models.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/google/generativeai/__pycache__/responder.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/google/generativeai/__pycache__/string_utils.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/google/generativeai/__pycache__/text.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/answer_types.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/caching_types.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/content_types.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/discuss_types.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/helper_types.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/model_types.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/retriever_types.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/safety_types.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/google/generativeai/types/answer_types.py +58 -0
  16. .venv/lib/python3.11/site-packages/google/generativeai/types/caching_types.py +83 -0
  17. .venv/lib/python3.11/site-packages/google/generativeai/types/content_types.py +985 -0
  18. .venv/lib/python3.11/site-packages/google/generativeai/types/discuss_types.py +208 -0
  19. .venv/lib/python3.11/site-packages/google/generativeai/types/file_types.py +143 -0
  20. .venv/lib/python3.11/site-packages/google/generativeai/types/generation_types.py +759 -0
  21. .venv/lib/python3.11/site-packages/google/generativeai/types/image_types/__pycache__/__init__.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/google/generativeai/types/image_types/__pycache__/_image_types.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/google/generativeai/types/image_types/_image_types.py +440 -0
  24. .venv/lib/python3.11/site-packages/google/generativeai/types/model_types.py +390 -0
  25. .venv/lib/python3.11/site-packages/google/generativeai/types/palm_safety_types.py +286 -0
  26. .venv/lib/python3.11/site-packages/google/generativeai/types/safety_types.py +303 -0
  27. .venv/lib/python3.11/site-packages/google/generativeai/types/text_types.py +32 -0
  28. .venv/lib/python3.11/site-packages/google/logging/type/__pycache__/http_request_pb2.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/google/logging/type/__pycache__/log_severity_pb2.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/google/logging/type/http_request.proto +95 -0
  31. .venv/lib/python3.11/site-packages/google/logging/type/log_severity.proto +71 -0
  32. .venv/lib/python3.11/site-packages/google/logging/type/log_severity_pb2.py +44 -0
  33. .venv/lib/python3.11/site-packages/google/protobuf/__init__.py +10 -0
  34. .venv/lib/python3.11/site-packages/google/protobuf/any.py +39 -0
  35. .venv/lib/python3.11/site-packages/google/protobuf/any_pb2.py +37 -0
  36. .venv/lib/python3.11/site-packages/google/protobuf/api_pb2.py +43 -0
  37. .venv/lib/python3.11/site-packages/google/protobuf/descriptor.py +1511 -0
  38. .venv/lib/python3.11/site-packages/google/protobuf/descriptor_database.py +154 -0
  39. .venv/lib/python3.11/site-packages/google/protobuf/descriptor_pb2.py +0 -0
  40. .venv/lib/python3.11/site-packages/google/protobuf/descriptor_pool.py +1355 -0
  41. .venv/lib/python3.11/site-packages/google/protobuf/duration.py +100 -0
  42. .venv/lib/python3.11/site-packages/google/protobuf/duration_pb2.py +37 -0
  43. .venv/lib/python3.11/site-packages/google/protobuf/empty_pb2.py +37 -0
  44. .venv/lib/python3.11/site-packages/google/protobuf/field_mask_pb2.py +37 -0
  45. .venv/lib/python3.11/site-packages/google/protobuf/internal/_parameterized.py +420 -0
  46. .venv/lib/python3.11/site-packages/google/protobuf/internal/containers.py +677 -0
  47. .venv/lib/python3.11/site-packages/google/protobuf/internal/encoder.py +806 -0
  48. .venv/lib/python3.11/site-packages/google/protobuf/internal/python_edition_defaults.py +5 -0
  49. .venv/lib/python3.11/site-packages/google/protobuf/internal/testing_refleaks.py +119 -0
  50. .venv/lib/python3.11/site-packages/google/protobuf/internal/well_known_types.py +678 -0
.venv/lib/python3.11/site-packages/google/generativeai/__pycache__/files.cpython-311.pyc ADDED
Binary file (4.88 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/__pycache__/generative_models.cpython-311.pyc ADDED
Binary file (35.1 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/__pycache__/models.cpython-311.pyc ADDED
Binary file (17.4 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/__pycache__/responder.cpython-311.pyc ADDED
Binary file (24.2 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/__pycache__/string_utils.cpython-311.pyc ADDED
Binary file (3.78 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/__pycache__/text.cpython-311.pyc ADDED
Binary file (14.6 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/answer_types.cpython-311.pyc ADDED
Binary file (2.09 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/caching_types.cpython-311.pyc ADDED
Binary file (2.89 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/content_types.cpython-311.pyc ADDED
Binary file (43.7 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/discuss_types.cpython-311.pyc ADDED
Binary file (8.6 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/helper_types.cpython-311.pyc ADDED
Binary file (3.42 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/model_types.cpython-311.pyc ADDED
Binary file (18.4 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/retriever_types.cpython-311.pyc ADDED
Binary file (66.8 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/safety_types.cpython-311.pyc ADDED
Binary file (13.6 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/types/answer_types.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ from typing import Union
18
+
19
+ from google.generativeai import protos
20
+
21
+ __all__ = ["Answer"]
22
+
23
+ FinishReason = protos.Candidate.FinishReason
24
+
25
+ FinishReasonOptions = Union[int, str, FinishReason]
26
+
27
+ _FINISH_REASONS: dict[FinishReasonOptions, FinishReason] = {
28
+ FinishReason.FINISH_REASON_UNSPECIFIED: FinishReason.FINISH_REASON_UNSPECIFIED,
29
+ 0: FinishReason.FINISH_REASON_UNSPECIFIED,
30
+ "finish_reason_unspecified": FinishReason.FINISH_REASON_UNSPECIFIED,
31
+ "unspecified": FinishReason.FINISH_REASON_UNSPECIFIED,
32
+ FinishReason.STOP: FinishReason.STOP,
33
+ 1: FinishReason.STOP,
34
+ "finish_reason_stop": FinishReason.STOP,
35
+ "stop": FinishReason.STOP,
36
+ FinishReason.MAX_TOKENS: FinishReason.MAX_TOKENS,
37
+ 2: FinishReason.MAX_TOKENS,
38
+ "finish_reason_max_tokens": FinishReason.MAX_TOKENS,
39
+ "max_tokens": FinishReason.MAX_TOKENS,
40
+ FinishReason.SAFETY: FinishReason.SAFETY,
41
+ 3: FinishReason.SAFETY,
42
+ "finish_reason_safety": FinishReason.SAFETY,
43
+ "safety": FinishReason.SAFETY,
44
+ FinishReason.RECITATION: FinishReason.RECITATION,
45
+ 4: FinishReason.RECITATION,
46
+ "finish_reason_recitation": FinishReason.RECITATION,
47
+ "recitation": FinishReason.RECITATION,
48
+ FinishReason.OTHER: FinishReason.OTHER,
49
+ 5: FinishReason.OTHER,
50
+ "finish_reason_other": FinishReason.OTHER,
51
+ "other": FinishReason.OTHER,
52
+ }
53
+
54
+
55
+ def to_finish_reason(x: FinishReasonOptions) -> FinishReason:
56
+ if isinstance(x, str):
57
+ x = x.lower()
58
+ return _FINISH_REASONS[x]
.venv/lib/python3.11/site-packages/google/generativeai/types/caching_types.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2024 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ import datetime
18
+ from typing import Union
19
+ from typing_extensions import TypedDict
20
+
21
+ __all__ = [
22
+ "ExpireTime",
23
+ "TTL",
24
+ "TTLTypes",
25
+ "ExpireTimeTypes",
26
+ ]
27
+
28
+
29
+ class TTL(TypedDict):
30
+ # Represents datetime.datetime.now() + desired ttl
31
+ seconds: int
32
+ nanos: int
33
+
34
+
35
+ class ExpireTime(TypedDict):
36
+ # Represents seconds of UTC time since Unix epoch
37
+ seconds: int
38
+ nanos: int
39
+
40
+
41
+ TTLTypes = Union[TTL, int, datetime.timedelta]
42
+ ExpireTimeTypes = Union[ExpireTime, int, datetime.datetime]
43
+
44
+
45
+ def to_optional_ttl(ttl: TTLTypes | None) -> TTL | None:
46
+ if ttl is None:
47
+ return None
48
+ elif isinstance(ttl, datetime.timedelta):
49
+ return {
50
+ "seconds": int(ttl.total_seconds()),
51
+ "nanos": int(ttl.microseconds * 1000),
52
+ }
53
+ elif isinstance(ttl, dict):
54
+ return ttl
55
+ elif isinstance(ttl, int):
56
+ return {"seconds": ttl, "nanos": 0}
57
+ else:
58
+ raise TypeError(
59
+ f"Could not convert input to `ttl` \n'" f" type: {type(ttl)}\n",
60
+ ttl,
61
+ )
62
+
63
+
64
+ def to_optional_expire_time(expire_time: ExpireTimeTypes | None) -> ExpireTime | None:
65
+ if expire_time is None:
66
+ return expire_time
67
+ elif isinstance(expire_time, datetime.datetime):
68
+ timestamp = expire_time.timestamp()
69
+ seconds = int(timestamp)
70
+ nanos = int((seconds % 1) * 1000)
71
+ return {
72
+ "seconds": seconds,
73
+ "nanos": nanos,
74
+ }
75
+ elif isinstance(expire_time, dict):
76
+ return expire_time
77
+ elif isinstance(expire_time, int):
78
+ return {"seconds": expire_time, "nanos": 0}
79
+ else:
80
+ raise TypeError(
81
+ f"Could not convert input to `expire_time` \n'" f" type: {type(expire_time)}\n",
82
+ expire_time,
83
+ )
.venv/lib/python3.11/site-packages/google/generativeai/types/content_types.py ADDED
@@ -0,0 +1,985 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+
18
+ from collections.abc import Iterable, Mapping, Sequence
19
+ import io
20
+ import inspect
21
+ import mimetypes
22
+ import pathlib
23
+ import typing
24
+ from typing import Any, Callable, Union
25
+ from typing_extensions import TypedDict
26
+
27
+ import pydantic
28
+
29
+ from google.generativeai.types import file_types
30
+ from google.generativeai import protos
31
+
32
+ if typing.TYPE_CHECKING:
33
+ import PIL.Image
34
+ import PIL.ImageFile
35
+ import IPython.display
36
+
37
+ IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
38
+ else:
39
+ IMAGE_TYPES = ()
40
+ try:
41
+ import PIL.Image
42
+ import PIL.ImageFile
43
+
44
+ IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)
45
+ except ImportError:
46
+ PIL = None
47
+
48
+ try:
49
+ import IPython.display
50
+
51
+ IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,)
52
+ except ImportError:
53
+ IPython = None
54
+
55
+
56
+ __all__ = [
57
+ "BlobDict",
58
+ "BlobType",
59
+ "PartDict",
60
+ "PartType",
61
+ "ContentDict",
62
+ "ContentType",
63
+ "StrictContentType",
64
+ "ContentsType",
65
+ "FunctionDeclaration",
66
+ "CallableFunctionDeclaration",
67
+ "FunctionDeclarationType",
68
+ "Tool",
69
+ "ToolDict",
70
+ "ToolsType",
71
+ "FunctionLibrary",
72
+ "FunctionLibraryType",
73
+ ]
74
+
75
+ Mode = protos.DynamicRetrievalConfig.Mode
76
+
77
+ ModeOptions = Union[int, str, Mode]
78
+
79
+ _MODE: dict[ModeOptions, Mode] = {
80
+ Mode.MODE_UNSPECIFIED: Mode.MODE_UNSPECIFIED,
81
+ 0: Mode.MODE_UNSPECIFIED,
82
+ "mode_unspecified": Mode.MODE_UNSPECIFIED,
83
+ "unspecified": Mode.MODE_UNSPECIFIED,
84
+ Mode.MODE_DYNAMIC: Mode.MODE_DYNAMIC,
85
+ 1: Mode.MODE_DYNAMIC,
86
+ "mode_dynamic": Mode.MODE_DYNAMIC,
87
+ "dynamic": Mode.MODE_DYNAMIC,
88
+ }
89
+
90
+
91
+ def to_mode(x: ModeOptions) -> Mode:
92
+ if isinstance(x, str):
93
+ x = x.lower()
94
+ return _MODE[x]
95
+
96
+
97
+ def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
98
+ # If the image is a local file, return a file-based blob without any modification.
99
+ # Otherwise, return a lossless WebP blob (same quality with optimized size).
100
+ def file_blob(image: PIL.Image.Image) -> protos.Blob | None:
101
+ if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None:
102
+ return None
103
+ filename = str(image.filename)
104
+ if not pathlib.Path(filename).is_file():
105
+ return None
106
+
107
+ mime_type = image.get_format_mimetype()
108
+ image_bytes = pathlib.Path(filename).read_bytes()
109
+
110
+ return protos.Blob(mime_type=mime_type, data=image_bytes)
111
+
112
+ def webp_blob(image: PIL.Image.Image) -> protos.Blob:
113
+ # Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
114
+ image_io = io.BytesIO()
115
+ image.save(image_io, format="webp", lossless=True)
116
+ image_io.seek(0)
117
+
118
+ mime_type = "image/webp"
119
+ image_bytes = image_io.read()
120
+
121
+ return protos.Blob(mime_type=mime_type, data=image_bytes)
122
+
123
+ return file_blob(image) or webp_blob(image)
124
+
125
+
126
+ def image_to_blob(image) -> protos.Blob:
127
+ if PIL is not None:
128
+ if isinstance(image, PIL.Image.Image):
129
+ return _pil_to_blob(image)
130
+
131
+ if IPython is not None:
132
+ if isinstance(image, IPython.display.Image):
133
+ name = image.filename
134
+ if name is None:
135
+ raise ValueError(
136
+ "Conversion failed. The `IPython.display.Image` can only be converted if "
137
+ "it is constructed from a local file. Please ensure you are using the format: Image(filename='...')."
138
+ )
139
+ mime_type, _ = mimetypes.guess_type(name)
140
+ if mime_type is None:
141
+ mime_type = "image/unknown"
142
+
143
+ return protos.Blob(mime_type=mime_type, data=image.data)
144
+
145
+ raise TypeError(
146
+ "Image conversion failed. The input was expected to be of type `Image` "
147
+ "(either `PIL.Image.Image` or `IPython.display.Image`).\n"
148
+ f"However, received an object of type: {type(image)}.\n"
149
+ f"Object Value: {image}"
150
+ )
151
+
152
+
153
+ class BlobDict(TypedDict):
154
+ mime_type: str
155
+ data: bytes
156
+
157
+
158
+ def _convert_dict(d: Mapping) -> protos.Content | protos.Part | protos.Blob:
159
+ if is_content_dict(d):
160
+ content = dict(d)
161
+ if isinstance(parts := content["parts"], str):
162
+ content["parts"] = [parts]
163
+ content["parts"] = [to_part(part) for part in content["parts"]]
164
+ return protos.Content(content)
165
+ elif is_part_dict(d):
166
+ part = dict(d)
167
+ if "inline_data" in part:
168
+ part["inline_data"] = to_blob(part["inline_data"])
169
+ if "file_data" in part:
170
+ part["file_data"] = file_types.to_file_data(part["file_data"])
171
+ return protos.Part(part)
172
+ elif is_blob_dict(d):
173
+ blob = d
174
+ return protos.Blob(blob)
175
+ else:
176
+ raise KeyError(
177
+ "Unable to determine the intended type of the `dict`. "
178
+ "For `Content`, a 'parts' key is expected. "
179
+ "For `Part`, either an 'inline_data' or a 'text' key is expected. "
180
+ "For `Blob`, both 'mime_type' and 'data' keys are expected. "
181
+ f"However, the provided dictionary has the following keys: {list(d.keys())}"
182
+ )
183
+
184
+
185
+ def is_blob_dict(d):
186
+ return "mime_type" in d and "data" in d
187
+
188
+
189
+ if typing.TYPE_CHECKING:
190
+ BlobType = Union[
191
+ protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image
192
+ ] # Any for the images
193
+ else:
194
+ BlobType = Union[protos.Blob, BlobDict, Any]
195
+
196
+
197
+ def to_blob(blob: BlobType) -> protos.Blob:
198
+ if isinstance(blob, Mapping):
199
+ blob = _convert_dict(blob)
200
+
201
+ if isinstance(blob, protos.Blob):
202
+ return blob
203
+ elif isinstance(blob, IMAGE_TYPES):
204
+ return image_to_blob(blob)
205
+ else:
206
+ if isinstance(blob, Mapping):
207
+ raise KeyError(
208
+ "Could not recognize the intended type of the `dict`\n" "A content should have "
209
+ )
210
+ raise TypeError(
211
+ "Could not create `Blob`, expected `Blob`, `dict` or an `Image` type"
212
+ "(`PIL.Image.Image` or `IPython.display.Image`).\n"
213
+ f"Got a: {type(blob)}\n"
214
+ f"Value: {blob}"
215
+ )
216
+
217
+
218
+ class PartDict(TypedDict):
219
+ text: str
220
+ inline_data: BlobType
221
+
222
+
223
+ # When you need a `Part` accept a part object, part-dict, blob or string
224
+ PartType = Union[
225
+ protos.Part,
226
+ PartDict,
227
+ BlobType,
228
+ str,
229
+ protos.FunctionCall,
230
+ protos.FunctionResponse,
231
+ file_types.FileDataType,
232
+ ]
233
+
234
+
235
+ def is_part_dict(d):
236
+ keys = list(d.keys())
237
+ if len(keys) != 1:
238
+ return False
239
+
240
+ key = keys[0]
241
+
242
+ return key in ["text", "inline_data", "function_call", "function_response", "file_data"]
243
+
244
+
245
+ def to_part(part: PartType):
246
+ if isinstance(part, Mapping):
247
+ part = _convert_dict(part)
248
+
249
+ if isinstance(part, protos.Part):
250
+ return part
251
+ elif isinstance(part, str):
252
+ return protos.Part(text=part)
253
+ elif isinstance(part, protos.FileData):
254
+ return protos.Part(file_data=part)
255
+ elif isinstance(part, (protos.File, file_types.File)):
256
+ return protos.Part(file_data=file_types.to_file_data(part))
257
+ elif isinstance(part, protos.FunctionCall):
258
+ return protos.Part(function_call=part)
259
+ elif isinstance(part, protos.FunctionResponse):
260
+ return protos.Part(function_response=part)
261
+
262
+ else:
263
+ # Maybe it can be turned into a blob?
264
+ return protos.Part(inline_data=to_blob(part))
265
+
266
+
267
+ class ContentDict(TypedDict):
268
+ parts: list[PartType]
269
+ role: str
270
+
271
+
272
+ def is_content_dict(d):
273
+ return "parts" in d
274
+
275
+
276
+ # When you need a message accept a `Content` object or dict, a list of parts,
277
+ # or a single part
278
+ ContentType = Union[protos.Content, ContentDict, Iterable[PartType], PartType]
279
+
280
+ # For generate_content, we're not guessing roles for [[parts],[parts],[parts]] yet.
281
+ StrictContentType = Union[protos.Content, ContentDict]
282
+
283
+
284
+ def to_content(content: ContentType):
285
+ if not content:
286
+ raise ValueError(
287
+ "Invalid input: 'content' argument must not be empty. Please provide a non-empty value."
288
+ )
289
+
290
+ if isinstance(content, Mapping):
291
+ content = _convert_dict(content)
292
+
293
+ if isinstance(content, protos.Content):
294
+ return content
295
+ elif isinstance(content, Iterable) and not isinstance(content, str):
296
+ return protos.Content(parts=[to_part(part) for part in content])
297
+ else:
298
+ # Maybe this is a Part?
299
+ return protos.Content(parts=[to_part(content)])
300
+
301
+
302
+ def strict_to_content(content: StrictContentType):
303
+ if isinstance(content, Mapping):
304
+ content = _convert_dict(content)
305
+
306
+ if isinstance(content, protos.Content):
307
+ return content
308
+ else:
309
+ raise TypeError(
310
+ "Invalid input type. Expected a `protos.Content` or a `dict` with a 'parts' key.\n"
311
+ f"However, received an object of type: {type(content)}.\n"
312
+ f"Object Value: {content}"
313
+ )
314
+
315
+
316
+ ContentsType = Union[ContentType, Iterable[StrictContentType], None]
317
+
318
+
319
+ def to_contents(contents: ContentsType) -> list[protos.Content]:
320
+ if contents is None:
321
+ return []
322
+
323
+ if isinstance(contents, Iterable) and not isinstance(contents, (str, Mapping)):
324
+ try:
325
+ # strict_to_content so [[parts], [parts]] doesn't assume roles.
326
+ contents = [strict_to_content(c) for c in contents]
327
+ return contents
328
+ except TypeError:
329
+ # If you get a TypeError here it's probably because that was a list
330
+ # of parts, not a list of contents, so fall back to `to_content`.
331
+ pass
332
+
333
+ contents = [to_content(contents)]
334
+ return contents
335
+
336
+
337
+ def _schema_for_class(cls: TypedDict) -> dict[str, Any]:
338
+ schema = _build_schema("dummy", {"dummy": (cls, pydantic.Field())})
339
+ return schema["properties"]["dummy"]
340
+
341
+
342
+ def _schema_for_function(
343
+ f: Callable[..., Any],
344
+ *,
345
+ descriptions: Mapping[str, str] | None = None,
346
+ required: Sequence[str] | None = None,
347
+ ) -> dict[str, Any]:
348
+ """Generates the OpenAPI Schema for a python function.
349
+
350
+ Args:
351
+ f: The function to generate an OpenAPI Schema for.
352
+ descriptions: Optional. A `{name: description}` mapping for annotating input
353
+ arguments of the function with user-provided descriptions. It
354
+ defaults to an empty dictionary (i.e. there will not be any
355
+ description for any of the inputs).
356
+ required: Optional. For the user to specify the set of required arguments in
357
+ function calls to `f`. If unspecified, it will be automatically
358
+ inferred from `f`.
359
+
360
+ Returns:
361
+ dict[str, Any]: The OpenAPI Schema for the function `f` in JSON format.
362
+ """
363
+ if descriptions is None:
364
+ descriptions = {}
365
+ defaults = dict(inspect.signature(f).parameters)
366
+
367
+ fields_dict = {}
368
+ for name, param in defaults.items():
369
+ if param.kind in (
370
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
371
+ inspect.Parameter.KEYWORD_ONLY,
372
+ inspect.Parameter.POSITIONAL_ONLY,
373
+ ):
374
+ # We do not support default values for now.
375
+ # default=(
376
+ # param.default if param.default != inspect.Parameter.empty
377
+ # else None
378
+ # ),
379
+ field = pydantic.Field(
380
+ # We support user-provided descriptions.
381
+ description=descriptions.get(name, None)
382
+ )
383
+
384
+ # 1. We infer the argument type here: use Any rather than None so
385
+ # it will not try to auto-infer the type based on the default value.
386
+ if param.annotation != inspect.Parameter.empty:
387
+ fields_dict[name] = param.annotation, field
388
+ else:
389
+ fields_dict[name] = Any, field
390
+
391
+ parameters = _build_schema(f.__name__, fields_dict)
392
+
393
+ # 6. Annotate required fields.
394
+ if required is not None:
395
+ # We use the user-provided "required" fields if specified.
396
+ parameters["required"] = required
397
+ else:
398
+ # Otherwise we infer it from the function signature.
399
+ parameters["required"] = [
400
+ k
401
+ for k in defaults
402
+ if (
403
+ defaults[k].default == inspect.Parameter.empty
404
+ and defaults[k].kind
405
+ in (
406
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
407
+ inspect.Parameter.KEYWORD_ONLY,
408
+ inspect.Parameter.POSITIONAL_ONLY,
409
+ )
410
+ )
411
+ ]
412
+ schema = dict(name=f.__name__, description=f.__doc__)
413
+ if parameters["properties"]:
414
+ schema["parameters"] = parameters
415
+
416
+ return schema
417
+
418
+
419
+ def _build_schema(fname, fields_dict):
420
+ parameters = pydantic.create_model(fname, **fields_dict).model_json_schema()
421
+ defs = parameters.pop("$defs", {})
422
+ # flatten the defs
423
+ for name, value in defs.items():
424
+ unpack_defs(value, defs)
425
+ unpack_defs(parameters, defs)
426
+
427
+ # 5. Nullable fields:
428
+ # * https://github.com/pydantic/pydantic/issues/1270
429
+ # * https://stackoverflow.com/a/58841311
430
+ # * https://github.com/pydantic/pydantic/discussions/4872
431
+ convert_to_nullable(parameters)
432
+ add_object_type(parameters)
433
+ # Postprocessing
434
+ # 4. Suppress unnecessary title generation:
435
+ # * https://github.com/pydantic/pydantic/issues/1051
436
+ # * http://cl/586221780
437
+ strip_titles(parameters)
438
+ return parameters
439
+
440
+
441
+ def unpack_defs(schema, defs):
442
+ properties = schema.get("properties", None)
443
+ if properties is None:
444
+ return
445
+
446
+ for name, value in properties.items():
447
+ ref_key = value.get("$ref", None)
448
+ if ref_key is not None:
449
+ ref = defs[ref_key.split("defs/")[-1]]
450
+ unpack_defs(ref, defs)
451
+ properties[name] = ref
452
+ continue
453
+
454
+ anyof = value.get("anyOf", None)
455
+ if anyof is not None:
456
+ for i, atype in enumerate(anyof):
457
+ ref_key = atype.get("$ref", None)
458
+ if ref_key is not None:
459
+ ref = defs[ref_key.split("defs/")[-1]]
460
+ unpack_defs(ref, defs)
461
+ anyof[i] = ref
462
+ continue
463
+
464
+ items = value.get("items", None)
465
+ if items is not None:
466
+ ref_key = items.get("$ref", None)
467
+ if ref_key is not None:
468
+ ref = defs[ref_key.split("defs/")[-1]]
469
+ unpack_defs(ref, defs)
470
+ value["items"] = ref
471
+ continue
472
+
473
+
474
+ def strip_titles(schema):
475
+ title = schema.pop("title", None)
476
+
477
+ properties = schema.get("properties", None)
478
+ if properties is not None:
479
+ for name, value in properties.items():
480
+ strip_titles(value)
481
+
482
+ items = schema.get("items", None)
483
+ if items is not None:
484
+ strip_titles(items)
485
+
486
+
487
+ def add_object_type(schema):
488
+ properties = schema.get("properties", None)
489
+ if properties is not None:
490
+ schema.pop("required", None)
491
+ schema["type"] = "object"
492
+ for name, value in properties.items():
493
+ add_object_type(value)
494
+
495
+ items = schema.get("items", None)
496
+ if items is not None:
497
+ add_object_type(items)
498
+
499
+
500
+ def convert_to_nullable(schema):
501
+ anyof = schema.pop("anyOf", None)
502
+ if anyof is not None:
503
+ if len(anyof) != 2:
504
+ raise ValueError(
505
+ "Invalid input: Type Unions are not supported, except for `Optional` types. "
506
+ "Please provide an `Optional` type or a non-Union type."
507
+ )
508
+ a, b = anyof
509
+ if a == {"type": "null"}:
510
+ schema.update(b)
511
+ elif b == {"type": "null"}:
512
+ schema.update(a)
513
+ else:
514
+ raise ValueError(
515
+ "Invalid input: Type Unions are not supported, except for `Optional` types. "
516
+ "Please provide an `Optional` type or a non-Union type."
517
+ )
518
+ schema["nullable"] = True
519
+
520
+ properties = schema.get("properties", None)
521
+ if properties is not None:
522
+ for name, value in properties.items():
523
+ convert_to_nullable(value)
524
+
525
+ items = schema.get("items", None)
526
+ if items is not None:
527
+ convert_to_nullable(items)
528
+
529
+
530
+ def _rename_schema_fields(schema):
531
+ if schema is None:
532
+ return schema
533
+
534
+ schema = schema.copy()
535
+
536
+ type_ = schema.pop("type", None)
537
+ if type_ is not None:
538
+ schema["type_"] = type_.upper()
539
+
540
+ format_ = schema.pop("format", None)
541
+ if format_ is not None:
542
+ schema["format_"] = format_
543
+
544
+ items = schema.pop("items", None)
545
+ if items is not None:
546
+ schema["items"] = _rename_schema_fields(items)
547
+
548
+ properties = schema.pop("properties", None)
549
+ if properties is not None:
550
+ schema["properties"] = {k: _rename_schema_fields(v) for k, v in properties.items()}
551
+
552
+ return schema
553
+
554
+
555
+ class FunctionDeclaration:
556
+ def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None):
557
+ """A class wrapping a `protos.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`."""
558
+ self._proto = protos.FunctionDeclaration(
559
+ name=name, description=description, parameters=_rename_schema_fields(parameters)
560
+ )
561
+
562
+ @property
563
+ def name(self) -> str:
564
+ return self._proto.name
565
+
566
+ @property
567
+ def description(self) -> str:
568
+ return self._proto.description
569
+
570
+ @property
571
+ def parameters(self) -> protos.Schema:
572
+ return self._proto.parameters
573
+
574
+ @classmethod
575
+ def from_proto(cls, proto) -> FunctionDeclaration:
576
+ self = cls(name="", description="", parameters={})
577
+ self._proto = proto
578
+ return self
579
+
580
+ def to_proto(self) -> protos.FunctionDeclaration:
581
+ return self._proto
582
+
583
+ @staticmethod
584
+ def from_function(function: Callable[..., Any], descriptions: dict[str, str] | None = None):
585
+ """Builds a `CallableFunctionDeclaration` from a python function.
586
+
587
+ The function should have type annotations.
588
+
589
+ This method is able to generate the schema for arguments annotated with types:
590
+
591
+ `AllowedTypes = float | int | str | list[AllowedTypes] | dict`
592
+
593
+ This method does not yet build a schema for `TypedDict`, that would allow you to specify the dictionary
594
+ contents. But you can build these manually.
595
+ """
596
+
597
+ if descriptions is None:
598
+ descriptions = {}
599
+
600
+ schema = _schema_for_function(function, descriptions=descriptions)
601
+
602
+ return CallableFunctionDeclaration(**schema, function=function)
603
+
604
+
605
+ StructType = dict[str, "ValueType"]
606
+ ValueType = Union[float, str, bool, StructType, list["ValueType"], None]
607
+
608
+
609
+ class CallableFunctionDeclaration(FunctionDeclaration):
610
+ """An extension of `FunctionDeclaration` that can be built from a python function, and is callable.
611
+
612
+ Note: The python function must have type annotations.
613
+ """
614
+
615
+ def __init__(
616
+ self,
617
+ *,
618
+ name: str,
619
+ description: str,
620
+ parameters: dict[str, Any] | None = None,
621
+ function: Callable[..., Any],
622
+ ):
623
+ super().__init__(name=name, description=description, parameters=parameters)
624
+ self.function = function
625
+
626
+ def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse:
627
+ result = self.function(**fc.args)
628
+ if not isinstance(result, dict):
629
+ result = {"result": result}
630
+ return protos.FunctionResponse(name=fc.name, response=result)
631
+
632
+
633
+ FunctionDeclarationType = Union[
634
+ FunctionDeclaration,
635
+ protos.FunctionDeclaration,
636
+ dict[str, Any],
637
+ Callable[..., Any],
638
+ ]
639
+
640
+
641
+ def _make_function_declaration(
642
+ fun: FunctionDeclarationType,
643
+ ) -> FunctionDeclaration | protos.FunctionDeclaration:
644
+ if isinstance(fun, (FunctionDeclaration, protos.FunctionDeclaration)):
645
+ return fun
646
+ elif isinstance(fun, dict):
647
+ if "function" in fun:
648
+ return CallableFunctionDeclaration(**fun)
649
+ else:
650
+ return FunctionDeclaration(**fun)
651
+ elif callable(fun):
652
+ return CallableFunctionDeclaration.from_function(fun)
653
+ else:
654
+ raise TypeError(
655
+ "Invalid input type. Expected an instance of `genai.FunctionDeclarationType`.\n"
656
+ f"However, received an object of type: {type(fun)}.\n"
657
+ f"Object Value: {fun}"
658
+ )
659
+
660
+
661
+ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.FunctionDeclaration:
662
+ if isinstance(fd, protos.FunctionDeclaration):
663
+ return fd
664
+
665
+ return fd.to_proto()
666
+
667
+
668
+ class DynamicRetrievalConfigDict(TypedDict):
669
+ mode: protos.DynamicRetrievalConfig.mode
670
+ dynamic_threshold: float
671
+
672
+
673
+ DynamicRetrievalConfig = Union[protos.DynamicRetrievalConfig, DynamicRetrievalConfigDict]
674
+
675
+
676
+ class GoogleSearchRetrievalDict(TypedDict):
677
+ dynamic_retrieval_config: DynamicRetrievalConfig
678
+
679
+
680
+ GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, GoogleSearchRetrievalDict]
681
+
682
+
683
+ def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType):
684
+ if isinstance(gsr, protos.GoogleSearchRetrieval):
685
+ return gsr
686
+ elif isinstance(gsr, Mapping):
687
+ drc = gsr.get("dynamic_retrieval_config", None)
688
+ if drc is not None and isinstance(drc, Mapping):
689
+ mode = drc.get("mode", None)
690
+ if mode is not None:
691
+ mode = to_mode(mode)
692
+ gsr = gsr.copy()
693
+ gsr["dynamic_retrieval_config"]["mode"] = mode
694
+ return protos.GoogleSearchRetrieval(gsr)
695
+ else:
696
+ raise TypeError(
697
+ "Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n"
698
+ f"However, received an object of type: {type(gsr)}.\n"
699
+ f"Object Value: {gsr}"
700
+ )
701
+
702
+
703
+ class Tool:
704
+ """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects,
705
+ protos.CodeExecution object, and protos.GoogleSearchRetrieval object."""
706
+
707
+ def __init__(
708
+ self,
709
+ *,
710
+ function_declarations: Iterable[FunctionDeclarationType] | None = None,
711
+ google_search_retrieval: GoogleSearchRetrievalType | None = None,
712
+ code_execution: protos.CodeExecution | None = None,
713
+ ):
714
+ # The main path doesn't use this but is seems useful.
715
+ if function_declarations is not None:
716
+ self._function_declarations = [
717
+ _make_function_declaration(f) for f in function_declarations
718
+ ]
719
+ self._index = {}
720
+ for fd in self._function_declarations:
721
+ name = fd.name
722
+ if name in self._index:
723
+ raise ValueError("")
724
+ self._index[fd.name] = fd
725
+ else:
726
+ # Consistent fields
727
+ self._function_declarations = []
728
+ self._index = {}
729
+
730
+ if google_search_retrieval is not None:
731
+ self._google_search_retrieval = _make_google_search_retrieval(google_search_retrieval)
732
+ else:
733
+ self._google_search_retrieval = None
734
+
735
+ self._proto = protos.Tool(
736
+ function_declarations=[_encode_fd(fd) for fd in self._function_declarations],
737
+ google_search_retrieval=google_search_retrieval,
738
+ code_execution=code_execution,
739
+ )
740
+
741
+ @property
742
+ def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]:
743
+ return self._function_declarations
744
+
745
+ @property
746
+ def google_search_retrieval(self) -> protos.GoogleSearchRetrieval:
747
+ return self._google_search_retrieval
748
+
749
+ @property
750
+ def code_execution(self) -> protos.CodeExecution:
751
+ return self._proto.code_execution
752
+
753
+ def __getitem__(
754
+ self, name: str | protos.FunctionCall
755
+ ) -> FunctionDeclaration | protos.FunctionDeclaration:
756
+ if not isinstance(name, str):
757
+ name = name.name
758
+
759
+ return self._index[name]
760
+
761
+ def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse | None:
762
+ declaration = self[fc]
763
+ if not callable(declaration):
764
+ return None
765
+
766
+ return declaration(fc)
767
+
768
+ def to_proto(self):
769
+ return self._proto
770
+
771
+
772
+ class ToolDict(TypedDict):
773
+ function_declarations: list[FunctionDeclarationType]
774
+
775
+
776
+ ToolType = Union[
777
+ str, Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType
778
+ ]
779
+
780
+
781
+ def _make_tool(tool: ToolType) -> Tool:
782
+ if isinstance(tool, Tool):
783
+ return tool
784
+ elif isinstance(tool, protos.Tool):
785
+ if "code_execution" in tool:
786
+ code_execution = tool.code_execution
787
+ else:
788
+ code_execution = None
789
+
790
+ if "google_search_retrieval" in tool:
791
+ google_search_retrieval = tool.google_search_retrieval
792
+ else:
793
+ google_search_retrieval = None
794
+
795
+ return Tool(
796
+ function_declarations=tool.function_declarations,
797
+ google_search_retrieval=google_search_retrieval,
798
+ code_execution=code_execution,
799
+ )
800
+ elif isinstance(tool, dict):
801
+ if (
802
+ "function_declarations" in tool
803
+ or "google_search_retrieval" in tool
804
+ or "code_execution" in tool
805
+ ):
806
+ return Tool(**tool)
807
+ else:
808
+ fd = tool
809
+ return Tool(function_declarations=[protos.FunctionDeclaration(**fd)])
810
+ elif isinstance(tool, str):
811
+ if tool.lower() == "code_execution":
812
+ return Tool(code_execution=protos.CodeExecution())
813
+ # Check to see if one of the mode enums matches
814
+ elif tool.lower() == "google_search_retrieval":
815
+ return Tool(google_search_retrieval=protos.GoogleSearchRetrieval())
816
+ else:
817
+ raise ValueError(
818
+ "The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval."
819
+ )
820
+ elif isinstance(tool, protos.CodeExecution):
821
+ return Tool(code_execution=tool)
822
+ elif isinstance(tool, protos.GoogleSearchRetrieval):
823
+ return Tool(google_search_retrieval=tool)
824
+ elif isinstance(tool, Iterable):
825
+ return Tool(function_declarations=tool)
826
+ else:
827
+ try:
828
+ return Tool(function_declarations=[tool])
829
+ except Exception as e:
830
+ raise TypeError(
831
+ "Invalid input type. Expected an instance of `genai.ToolType`.\n"
832
+ f"However, received an object of type: {type(tool)}.\n"
833
+ f"Object Value: {tool}"
834
+ ) from e
835
+
836
+
837
+ class FunctionLibrary:
838
+ """A container for a set of `Tool` objects, manages lookup and execution of their functions."""
839
+
840
+ def __init__(self, tools: Iterable[ToolType]):
841
+ tools = _make_tools(tools)
842
+ self._tools = list(tools)
843
+ self._index = {}
844
+ for tool in self._tools:
845
+ for declaration in tool.function_declarations:
846
+ name = declaration.name
847
+ if name in self._index:
848
+ raise ValueError(
849
+ f"Invalid operation: A `FunctionDeclaration` named '{name}' is already defined. "
850
+ "Each `FunctionDeclaration` must have a unique name. Please use a different name."
851
+ )
852
+ self._index[declaration.name] = declaration
853
+
854
+ def __getitem__(
855
+ self, name: str | protos.FunctionCall
856
+ ) -> FunctionDeclaration | protos.FunctionDeclaration:
857
+ if not isinstance(name, str):
858
+ name = name.name
859
+
860
+ return self._index[name]
861
+
862
+ def __call__(self, fc: protos.FunctionCall) -> protos.Part | None:
863
+ declaration = self[fc]
864
+ if not callable(declaration):
865
+ return None
866
+
867
+ response = declaration(fc)
868
+ return protos.Part(function_response=response)
869
+
870
+ def to_proto(self):
871
+ return [tool.to_proto() for tool in self._tools]
872
+
873
+
874
+ ToolsType = Union[Iterable[ToolType], ToolType]
875
+
876
+
877
+ def _make_tools(tools: ToolsType) -> list[Tool]:
878
+ if isinstance(tools, str):
879
+ if tools.lower() == "code_execution" or tools.lower() == "google_search_retrieval":
880
+ return [_make_tool(tools)]
881
+ else:
882
+ raise ValueError("The only string that can be passed as a tool is 'code_execution'.")
883
+ elif isinstance(tools, Iterable) and not isinstance(tools, Mapping):
884
+ tools = [_make_tool(t) for t in tools]
885
+ if len(tools) > 1 and all(len(t.function_declarations) == 1 for t in tools):
886
+ # flatten into a single tool.
887
+ tools = [_make_tool([t.function_declarations[0] for t in tools])]
888
+ return tools
889
+ else:
890
+ tool = tools
891
+ return [_make_tool(tool)]
892
+
893
+
894
+ FunctionLibraryType = Union[FunctionLibrary, ToolsType]
895
+
896
+
897
+ def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | None:
898
+ if lib is None:
899
+ return lib
900
+ elif isinstance(lib, FunctionLibrary):
901
+ return lib
902
+ else:
903
+ return FunctionLibrary(tools=lib)
904
+
905
+
906
+ FunctionCallingMode = protos.FunctionCallingConfig.Mode
907
+
908
+ # fmt: off
909
+ _FUNCTION_CALLING_MODE = {
910
+ 1: FunctionCallingMode.AUTO,
911
+ FunctionCallingMode.AUTO: FunctionCallingMode.AUTO,
912
+ "mode_auto": FunctionCallingMode.AUTO,
913
+ "auto": FunctionCallingMode.AUTO,
914
+
915
+ 2: FunctionCallingMode.ANY,
916
+ FunctionCallingMode.ANY: FunctionCallingMode.ANY,
917
+ "mode_any": FunctionCallingMode.ANY,
918
+ "any": FunctionCallingMode.ANY,
919
+
920
+ 3: FunctionCallingMode.NONE,
921
+ FunctionCallingMode.NONE: FunctionCallingMode.NONE,
922
+ "mode_none": FunctionCallingMode.NONE,
923
+ "none": FunctionCallingMode.NONE,
924
+ }
925
+ # fmt: on
926
+
927
+ FunctionCallingModeType = Union[FunctionCallingMode, str, int]
928
+
929
+
930
+ def to_function_calling_mode(x: FunctionCallingModeType) -> FunctionCallingMode:
931
+ if isinstance(x, str):
932
+ x = x.lower()
933
+ return _FUNCTION_CALLING_MODE[x]
934
+
935
+
936
+ class FunctionCallingConfigDict(TypedDict):
937
+ mode: FunctionCallingModeType
938
+ allowed_function_names: list[str]
939
+
940
+
941
+ FunctionCallingConfigType = Union[
942
+ FunctionCallingModeType, FunctionCallingConfigDict, protos.FunctionCallingConfig
943
+ ]
944
+
945
+
946
+ def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.FunctionCallingConfig:
947
+ if isinstance(obj, protos.FunctionCallingConfig):
948
+ return obj
949
+ elif isinstance(obj, (FunctionCallingMode, str, int)):
950
+ obj = {"mode": to_function_calling_mode(obj)}
951
+ elif isinstance(obj, dict):
952
+ obj = obj.copy()
953
+ mode = obj.pop("mode")
954
+ obj["mode"] = to_function_calling_mode(mode)
955
+ else:
956
+ raise TypeError(
957
+ "Invalid input type. Failed to convert input to `protos.FunctionCallingConfig`.\n"
958
+ f"Received an object of type: {type(obj)}.\n"
959
+ f"Object Value: {obj}"
960
+ )
961
+
962
+ return protos.FunctionCallingConfig(obj)
963
+
964
+
965
+ class ToolConfigDict:
966
+ function_calling_config: FunctionCallingConfigType
967
+
968
+
969
+ ToolConfigType = Union[ToolConfigDict, protos.ToolConfig]
970
+
971
+
972
+ def to_tool_config(obj: ToolConfigType) -> protos.ToolConfig:
973
+ if isinstance(obj, protos.ToolConfig):
974
+ return obj
975
+ elif isinstance(obj, dict):
976
+ fcc = obj.pop("function_calling_config")
977
+ fcc = to_function_calling_config(fcc)
978
+ obj["function_calling_config"] = fcc
979
+ return protos.ToolConfig(**obj)
980
+ else:
981
+ raise TypeError(
982
+ "Invalid input type. Failed to convert input to `protos.ToolConfig`.\n"
983
+ f"Received an object of type: {type(obj)}.\n"
984
+ f"Object Value: {obj}"
985
+ )
.venv/lib/python3.11/site-packages/google/generativeai/types/discuss_types.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Type definitions for the discuss service."""
16
+
17
+ import abc
18
+ import dataclasses
19
+ from typing import Any, Dict, Union, Iterable, Optional, Tuple, List
20
+ from typing_extensions import TypedDict
21
+
22
+ import google.ai.generativelanguage as glm
23
+ from google.generativeai import string_utils
24
+
25
+ from google.generativeai.types import safety_types
26
+ from google.generativeai.types import citation_types
27
+
28
+
29
+ __all__ = [
30
+ "MessageDict",
31
+ "MessageOptions",
32
+ "MessagesOptions",
33
+ "ExampleDict",
34
+ "ExampleOptions",
35
+ "ExamplesOptions",
36
+ "MessagePromptDict",
37
+ "MessagePromptOptions",
38
+ "ResponseDict",
39
+ "ChatResponse",
40
+ "AuthorError",
41
+ ]
42
+
43
+
44
+ class TokenCount(TypedDict):
45
+ token_count: int
46
+
47
+
48
+ class MessageDict(TypedDict):
49
+ """A dict representation of a `glm.Message`."""
50
+
51
+ author: str
52
+ content: str
53
+ citation_metadata: Optional[citation_types.CitationMetadataDict]
54
+
55
+
56
+ MessageOptions = Union[str, MessageDict, glm.Message]
57
+ MESSAGE_OPTIONS = (str, dict, glm.Message)
58
+
59
+ MessagesOptions = Union[
60
+ MessageOptions,
61
+ Iterable[MessageOptions],
62
+ ]
63
+ MESSAGES_OPTIONS = (MESSAGE_OPTIONS, Iterable)
64
+
65
+
66
+ class ExampleDict(TypedDict):
67
+ """A dict representation of a `glm.Example`."""
68
+
69
+ input: MessageOptions
70
+ output: MessageOptions
71
+
72
+
73
+ ExampleOptions = Union[
74
+ Tuple[MessageOptions, MessageOptions],
75
+ Iterable[MessageOptions],
76
+ ExampleDict,
77
+ glm.Example,
78
+ ]
79
+ EXAMPLE_OPTIONS = (glm.Example, dict, Iterable)
80
+ ExamplesOptions = Union[ExampleOptions, Iterable[ExampleOptions]]
81
+
82
+
83
+ class MessagePromptDict(TypedDict, total=False):
84
+ """A dict representation of a `glm.MessagePrompt`."""
85
+
86
+ context: str
87
+ examples: ExamplesOptions
88
+ messages: MessagesOptions
89
+
90
+
91
+ MessagePromptOptions = Union[
92
+ str,
93
+ glm.Message,
94
+ Iterable[Union[str, glm.Message]],
95
+ MessagePromptDict,
96
+ glm.MessagePrompt,
97
+ ]
98
+ MESSAGE_PROMPT_KEYS = {"context", "examples", "messages"}
99
+
100
+
101
+ class ResponseDict(TypedDict):
102
+ """A dict representation of a `glm.GenerateMessageResponse`."""
103
+
104
+ messages: List[MessageDict]
105
+ candidates: List[MessageDict]
106
+
107
+
108
+ @string_utils.prettyprint
109
+ @dataclasses.dataclass(init=False)
110
+ class ChatResponse(abc.ABC):
111
+ """A chat response from the model.
112
+
113
+ * Use `response.last` (settable) for easy access to the text of the last response.
114
+ (`messages[-1]['content']`)
115
+ * Use `response.messages` to access the message history (including `.last`).
116
+ * Use `response.candidates` to access all the responses generated by the model.
117
+
118
+ Other attributes are just saved from the arguments to `genai.chat`, so you
119
+ can easily continue a conversation:
120
+
121
+ ```
122
+ import google.generativeai as genai
123
+
124
+ genai.configure(api_key=os.environ['GOOGLE_API_KEY'])
125
+
126
+ response = genai.chat(messages=["Hello."])
127
+ print(response.last) # 'Hello! What can I help you with?'
128
+ response.reply("Can you tell me a joke?")
129
+ ```
130
+
131
+ See `genai.chat` for more details.
132
+
133
+ Attributes:
134
+ candidates: A list of candidate responses from the model.
135
+
136
+ The top candidate is appended to the `messages` field.
137
+
138
+ This list will contain a *maximum* of `candidate_count` candidates.
139
+ It may contain fewer (duplicates are dropped), it will contain at least one.
140
+
141
+ Note: The `temperature` field affects the variability of the responses. Low
142
+ temperatures will return few candidates. Setting `temperature=0` is deterministic,
143
+ so it will only ever return one candidate.
144
+ filters: This indicates which `types.SafetyCategory`(s) blocked a
145
+ candidate from this response, the lowest `types.HarmProbability`
146
+ that triggered a block, and the `types.HarmThreshold` setting for that category.
147
+ This indicates the smallest change to the `types.SafetySettings` that would be
148
+ necessary to unblock at least 1 response.
149
+
150
+ The blocking is configured by the `types.SafetySettings` in the request (or the
151
+ default `types.SafetySettings` of the API).
152
+ messages: Contains all the `messages` that were passed when the model was called,
153
+ plus the top `candidate` message.
154
+ model: The model name.
155
+ context: Text that should be provided to the model first, to ground the response.
156
+ examples: Examples of what the model should generate.
157
+ messages: A snapshot of the conversation history sorted chronologically.
158
+ temperature: Controls the randomness of the output. Must be positive.
159
+ candidate_count: The **maximum** number of generated response messages to return.
160
+ top_k: The maximum number of tokens to consider when sampling.
161
+ top_p: The maximum cumulative probability of tokens to consider when sampling.
162
+
163
+ """
164
+
165
+ model: str
166
+ context: str
167
+ examples: List[ExampleDict]
168
+ messages: List[Optional[MessageDict]]
169
+ temperature: Optional[float]
170
+ candidate_count: Optional[int]
171
+ candidates: List[MessageDict]
172
+ filters: List[safety_types.ContentFilterDict]
173
+ top_p: Optional[float] = None
174
+ top_k: Optional[float] = None
175
+
176
+ @property
177
+ @abc.abstractmethod
178
+ def last(self) -> Optional[str]:
179
+ """A settable property that provides simple access to the last response string
180
+
181
+ A shortcut for `response.messages[0]['content']`.
182
+ """
183
+ pass
184
+
185
+ def to_dict(self) -> Dict[str, Any]:
186
+ result = {
187
+ "model": self.model,
188
+ "context": self.context,
189
+ "examples": self.examples,
190
+ "messages": self.messages,
191
+ "temperature": self.temperature,
192
+ "candidate_count": self.candidate_count,
193
+ "top_p": self.top_p,
194
+ "top_k": self.top_k,
195
+ "candidates": self.candidates,
196
+ }
197
+ return result
198
+
199
+ @abc.abstractmethod
200
+ def reply(self, message: MessageOptions) -> "ChatResponse":
201
+ "Add a message to the conversation, and get the model's response."
202
+ pass
203
+
204
+
205
+ class AuthorError(Exception):
206
+ """Raised by the `chat` (or `reply`) functions when the author list can't be normalized."""
207
+
208
+ pass
.venv/lib/python3.11/site-packages/google/generativeai/types/file_types.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ import datetime
18
+ from typing import Any, Union
19
+ from typing_extensions import TypedDict
20
+
21
+ from google.rpc.status_pb2 import Status
22
+ from google.generativeai.client import get_default_file_client
23
+
24
+ from google.generativeai import protos
25
+
26
+ import pprint
27
+
28
+
29
+ class File:
30
+ def __init__(self, proto: protos.File | File | dict):
31
+ if isinstance(proto, File):
32
+ proto = proto.to_proto()
33
+ self._proto = protos.File(proto)
34
+
35
+ def to_proto(self) -> protos.File:
36
+ return self._proto
37
+
38
+ def to_dict(self) -> dict[str, Any]:
39
+ return type(self._proto).to_dict(self._proto, use_integers_for_enums=False)
40
+
41
+ def __str__(self):
42
+ def sort_key(pair):
43
+ name, value = pair
44
+ if name == "name":
45
+ return ""
46
+ elif "time" in name:
47
+ return "zz_" + name
48
+ else:
49
+ return name
50
+
51
+ dict_format = dict(sorted(self.to_dict().items(), key=sort_key))
52
+ dict_format = pprint.pformat(dict_format, sort_dicts=False)
53
+ dict_format = "{\n " + dict_format[1:]
54
+ dict_format = "\n ".join(dict_format.splitlines())
55
+ return dict_format.join(["genai.File(", ")"])
56
+
57
+ __repr__ = __str__
58
+
59
+ @property
60
+ def name(self) -> str:
61
+ return self._proto.name
62
+
63
+ @property
64
+ def display_name(self) -> str:
65
+ return self._proto.display_name
66
+
67
+ @property
68
+ def mime_type(self) -> str:
69
+ return self._proto.mime_type
70
+
71
+ @property
72
+ def size_bytes(self) -> int:
73
+ return self._proto.size_bytes
74
+
75
+ @property
76
+ def create_time(self) -> datetime.datetime:
77
+ return self._proto.create_time
78
+
79
+ @property
80
+ def update_time(self) -> datetime.datetime:
81
+ return self._proto.update_time
82
+
83
+ @property
84
+ def expiration_time(self) -> datetime.datetime:
85
+ return self._proto.expiration_time
86
+
87
+ @property
88
+ def sha256_hash(self) -> bytes:
89
+ return self._proto.sha256_hash
90
+
91
+ @property
92
+ def uri(self) -> str:
93
+ return self._proto.uri
94
+
95
+ @property
96
+ def state(self) -> protos.File.State:
97
+ return self._proto.state
98
+
99
+ @property
100
+ def video_metadata(self) -> protos.VideoMetadata:
101
+ return self._proto.video_metadata
102
+
103
+ @property
104
+ def error(self) -> Status:
105
+ return self._proto.error
106
+
107
+ def delete(self):
108
+ client = get_default_file_client()
109
+ client.delete_file(name=self.name)
110
+
111
+
112
+ class FileDataDict(TypedDict):
113
+ mime_type: str
114
+ file_uri: str
115
+
116
+
117
+ FileDataType = Union[FileDataDict, protos.FileData, protos.File, File]
118
+
119
+
120
+ def to_file_data(file_data: FileDataType):
121
+ if isinstance(file_data, dict):
122
+ if "file_uri" in file_data:
123
+ file_data = protos.FileData(file_data)
124
+ else:
125
+ file_data = protos.File(file_data)
126
+
127
+ if isinstance(file_data, File):
128
+ file_data = file_data.to_proto()
129
+
130
+ if isinstance(file_data, protos.File):
131
+ file_data = protos.FileData(
132
+ mime_type=file_data.mime_type,
133
+ file_uri=file_data.uri,
134
+ )
135
+
136
+ if isinstance(file_data, protos.FileData):
137
+ return file_data
138
+ else:
139
+ raise TypeError(
140
+ f"Invalid input type. Failed to convert input to `FileData`.\n"
141
+ f"Received an object of type: {type(file_data)}.\n"
142
+ f"Object Value: {file_data}"
143
+ )
.venv/lib/python3.11/site-packages/google/generativeai/types/generation_types.py ADDED
@@ -0,0 +1,759 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ import collections
18
+ import contextlib
19
+ from collections.abc import Iterable, AsyncIterable, Mapping
20
+ import dataclasses
21
+ import itertools
22
+ import json
23
+ import sys
24
+ import textwrap
25
+ from typing import Union, Any
26
+ from typing_extensions import TypedDict
27
+ import types
28
+
29
+ import google.protobuf.json_format
30
+ import google.api_core.exceptions
31
+
32
+ from google.generativeai import protos
33
+ from google.generativeai import string_utils
34
+ from google.generativeai.types import content_types
35
+ from google.generativeai.responder import _rename_schema_fields
36
+
37
+ __all__ = [
38
+ "AsyncGenerateContentResponse",
39
+ "BlockedPromptException",
40
+ "StopCandidateException",
41
+ "IncompleteIterationError",
42
+ "BrokenResponseError",
43
+ "GenerationConfigDict",
44
+ "GenerationConfigType",
45
+ "GenerationConfig",
46
+ "GenerateContentResponse",
47
+ ]
48
+
49
+ if sys.version_info < (3, 10):
50
+
51
+ def aiter(obj):
52
+ return obj.__aiter__()
53
+
54
+ async def anext(obj, default=None):
55
+ try:
56
+ return await obj.__anext__()
57
+ except StopAsyncIteration:
58
+ if default is not None:
59
+ return default
60
+ else:
61
+ raise
62
+
63
+
64
+ class BlockedPromptException(Exception):
65
+ pass
66
+
67
+
68
+ class StopCandidateException(Exception):
69
+ pass
70
+
71
+
72
+ class IncompleteIterationError(Exception):
73
+ pass
74
+
75
+
76
+ class BrokenResponseError(Exception):
77
+ pass
78
+
79
+
80
+ class GenerationConfigDict(TypedDict, total=False):
81
+ # TODO(markdaoust): Python 3.11+ use `NotRequired`, ref: https://peps.python.org/pep-0655/
82
+ candidate_count: int
83
+ stop_sequences: Iterable[str]
84
+ max_output_tokens: int
85
+ temperature: float
86
+ response_mime_type: str
87
+ response_schema: protos.Schema | Mapping[str, Any] # fmt: off
88
+ presence_penalty: float
89
+ frequency_penalty: float
90
+
91
+
92
+ @dataclasses.dataclass
93
+ class GenerationConfig:
94
+ """A simple dataclass used to configure the generation parameters of `GenerativeModel.generate_content`.
95
+
96
+ Attributes:
97
+ candidate_count:
98
+ Number of generated responses to return.
99
+ stop_sequences:
100
+ The set of character sequences (up
101
+ to 5) that will stop output generation. If
102
+ specified, the API will stop at the first
103
+ appearance of a stop sequence. The stop sequence
104
+ will not be included as part of the response.
105
+ max_output_tokens:
106
+ The maximum number of tokens to include in a
107
+ candidate.
108
+
109
+ If unset, this will default to output_token_limit specified
110
+ in the model's specification.
111
+ temperature:
112
+ Controls the randomness of the output. Note: The
113
+ default value varies by model, see the `Model.temperature`
114
+ attribute of the `Model` returned the `genai.get_model`
115
+ function.
116
+
117
+ Values can range from [0.0,1.0], inclusive. A value closer
118
+ to 1.0 will produce responses that are more varied and
119
+ creative, while a value closer to 0.0 will typically result
120
+ in more straightforward responses from the model.
121
+ top_p:
122
+ Optional. The maximum cumulative probability of tokens to
123
+ consider when sampling.
124
+
125
+ The model uses combined Top-k and nucleus sampling.
126
+
127
+ Tokens are sorted based on their assigned probabilities so
128
+ that only the most likely tokens are considered. Top-k
129
+ sampling directly limits the maximum number of tokens to
130
+ consider, while Nucleus sampling limits number of tokens
131
+ based on the cumulative probability.
132
+
133
+ Note: The default value varies by model, see the
134
+ `Model.top_p` attribute of the `Model` returned the
135
+ `genai.get_model` function.
136
+
137
+ top_k (int):
138
+ Optional. The maximum number of tokens to consider when
139
+ sampling.
140
+
141
+ The model uses combined Top-k and nucleus sampling.
142
+
143
+ Top-k sampling considers the set of `top_k` most probable
144
+ tokens. Defaults to 40.
145
+
146
+ Note: The default value varies by model, see the
147
+ `Model.top_k` attribute of the `Model` returned the
148
+ `genai.get_model` function.
149
+ response_mime_type:
150
+ Optional. Output response mimetype of the generated candidate text.
151
+
152
+ Supported mimetype:
153
+ `text/plain`: (default) Text output.
154
+ `text/x-enum`: for use with a string-enum in `response_schema`
155
+ `application/json`: JSON response in the candidates.
156
+
157
+ response_schema:
158
+ Optional. Specifies the format of the JSON requested if response_mime_type is
159
+ `application/json`.
160
+ presence_penalty:
161
+ Optional.
162
+ frequency_penalty:
163
+ Optional.
164
+ """
165
+
166
+ candidate_count: int | None = None
167
+ stop_sequences: Iterable[str] | None = None
168
+ max_output_tokens: int | None = None
169
+ temperature: float | None = None
170
+ top_p: float | None = None
171
+ top_k: int | None = None
172
+ response_mime_type: str | None = None
173
+ response_schema: protos.Schema | Mapping[str, Any] | type | None = None
174
+ presence_penalty: float | None = None
175
+ frequency_penalty: float | None = None
176
+
177
+
178
+ GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig]
179
+
180
+
181
+ def _normalize_schema(generation_config):
182
+ # Convert response_schema to protos.Schema for request
183
+ response_schema = generation_config.get("response_schema", None)
184
+ if response_schema is None:
185
+ return
186
+
187
+ if isinstance(response_schema, protos.Schema):
188
+ return
189
+
190
+ if isinstance(response_schema, type):
191
+ response_schema = content_types._schema_for_class(response_schema)
192
+ elif isinstance(response_schema, types.GenericAlias):
193
+ if not str(response_schema).startswith("list["):
194
+ raise ValueError(
195
+ f"Invalid input: Could not understand the type of '{response_schema}'. "
196
+ "Expected one of the following types: `int`, `float`, `str`, `bool`, `enum`, "
197
+ "`typing_extensions.TypedDict`, `dataclass` or `list[...]`."
198
+ )
199
+ response_schema = content_types._schema_for_class(response_schema)
200
+
201
+ response_schema = _rename_schema_fields(response_schema)
202
+ generation_config["response_schema"] = protos.Schema(response_schema)
203
+
204
+
205
+ def to_generation_config_dict(generation_config: GenerationConfigType):
206
+ if generation_config is None:
207
+ return {}
208
+ elif isinstance(generation_config, protos.GenerationConfig):
209
+ schema = generation_config.response_schema
210
+ generation_config = type(generation_config).to_dict(
211
+ generation_config
212
+ ) # pytype: disable=attribute-error
213
+ generation_config["response_schema"] = schema
214
+ return generation_config
215
+ elif isinstance(generation_config, GenerationConfig):
216
+ generation_config = dataclasses.asdict(generation_config)
217
+ _normalize_schema(generation_config)
218
+ return {key: value for key, value in generation_config.items() if value is not None}
219
+ elif hasattr(generation_config, "keys"):
220
+ generation_config = dict(generation_config)
221
+ _normalize_schema(generation_config)
222
+ return generation_config
223
+ else:
224
+ raise TypeError(
225
+ "Invalid input type. Expected a `dict` or `GenerationConfig` for `generation_config`.\n"
226
+ f"However, received an object of type: {type(generation_config)}.\n"
227
+ f"Object Value: {generation_config}"
228
+ )
229
+
230
+
231
+ def _join_citation_metadatas(
232
+ citation_metadatas: Iterable[protos.CitationMetadata],
233
+ ):
234
+ citation_metadatas = list(citation_metadatas)
235
+ return citation_metadatas[-1]
236
+
237
+
238
+ def _join_safety_ratings_lists(
239
+ safety_ratings_lists: Iterable[list[protos.SafetyRating]],
240
+ ):
241
+ ratings = {}
242
+ blocked = collections.defaultdict(list)
243
+
244
+ for safety_ratings_list in safety_ratings_lists:
245
+ for rating in safety_ratings_list:
246
+ ratings[rating.category] = rating.probability
247
+ blocked[rating.category].append(rating.blocked)
248
+
249
+ blocked = {category: any(blocked) for category, blocked in blocked.items()}
250
+
251
+ safety_list = []
252
+ for (category, probability), blocked in zip(ratings.items(), blocked.values()):
253
+ safety_list.append(
254
+ protos.SafetyRating(category=category, probability=probability, blocked=blocked)
255
+ )
256
+
257
+ return safety_list
258
+
259
+
260
+ def _join_contents(contents: Iterable[protos.Content]):
261
+ contents = tuple(contents)
262
+ roles = [c.role for c in contents if c.role]
263
+ if roles:
264
+ role = roles[0]
265
+ else:
266
+ role = ""
267
+
268
+ parts = []
269
+ for content in contents:
270
+ parts.extend(content.parts)
271
+
272
+ merged_parts = []
273
+ last = parts[0]
274
+ for part in parts[1:]:
275
+ if "text" in last and "text" in part:
276
+ last = protos.Part(text=last.text + part.text)
277
+ continue
278
+
279
+ # Can we merge the new thing into last?
280
+ # If not, put last in list of parts, and new thing becomes last
281
+ if "executable_code" in last and "executable_code" in part:
282
+ last = protos.Part(
283
+ executable_code=_join_executable_code(last.executable_code, part.executable_code)
284
+ )
285
+ continue
286
+
287
+ if "code_execution_result" in last and "code_execution_result" in part:
288
+ last = protos.Part(
289
+ code_execution_result=_join_code_execution_result(
290
+ last.code_execution_result, part.code_execution_result
291
+ )
292
+ )
293
+ continue
294
+
295
+ merged_parts.append(last)
296
+ last = part
297
+
298
+ merged_parts.append(last)
299
+
300
+ return protos.Content(
301
+ role=role,
302
+ parts=merged_parts,
303
+ )
304
+
305
+
306
+ def _join_executable_code(code_1, code_2):
307
+ return protos.ExecutableCode(language=code_1.language, code=code_1.code + code_2.code)
308
+
309
+
310
+ def _join_code_execution_result(result_1, result_2):
311
+ return protos.CodeExecutionResult(
312
+ outcome=result_2.outcome, output=result_1.output + result_2.output
313
+ )
314
+
315
+
316
+ def _join_candidates(candidates: Iterable[protos.Candidate]):
317
+ """Joins stream chunks of a single candidate."""
318
+ candidates = tuple(candidates)
319
+
320
+ index = candidates[0].index # These should all be the same.
321
+
322
+ return protos.Candidate(
323
+ index=index,
324
+ content=_join_contents([c.content for c in candidates]),
325
+ finish_reason=candidates[-1].finish_reason,
326
+ safety_ratings=_join_safety_ratings_lists([c.safety_ratings for c in candidates]),
327
+ citation_metadata=_join_citation_metadatas([c.citation_metadata for c in candidates]),
328
+ token_count=candidates[-1].token_count,
329
+ )
330
+
331
+
332
+ def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]):
333
+ """Joins stream chunks where each chunk is a list of candidate chunks."""
334
+ # Assuming that is a candidate ends, it is no longer returned in the list of
335
+ # candidates and that's why candidates have an index
336
+ candidates = collections.defaultdict(list)
337
+ for candidate_list in candidate_lists:
338
+ for candidate in candidate_list:
339
+ candidates[candidate.index].append(candidate)
340
+
341
+ new_candidates = []
342
+ for index, candidate_parts in sorted(candidates.items()):
343
+ new_candidates.append(_join_candidates(candidate_parts))
344
+
345
+ return new_candidates
346
+
347
+
348
+ def _join_prompt_feedbacks(
349
+ prompt_feedbacks: Iterable[protos.GenerateContentResponse.PromptFeedback],
350
+ ):
351
+ # Always return the first prompt feedback.
352
+ return next(iter(prompt_feedbacks))
353
+
354
+
355
+ def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]):
356
+ chunks = tuple(chunks)
357
+ if "usage_metadata" in chunks[-1]:
358
+ usage_metadata = chunks[-1].usage_metadata
359
+ else:
360
+ usage_metadata = None
361
+
362
+ if "model_version" in chunks[-1]:
363
+ model_version = chunks[-1].model_version
364
+ else:
365
+ model_version = None
366
+
367
+ return protos.GenerateContentResponse(
368
+ candidates=_join_candidate_lists(c.candidates for c in chunks),
369
+ prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
370
+ usage_metadata=usage_metadata,
371
+ model_version=model_version,
372
+ )
373
+
374
+
375
+ _INCOMPLETE_ITERATION_MESSAGE = """\
376
+ Please let the response complete iteration before accessing the final accumulated
377
+ attributes (or call `response.resolve()`)"""
378
+
379
+
380
+ class BaseGenerateContentResponse:
381
+ def __init__(
382
+ self,
383
+ done: bool,
384
+ iterator: (
385
+ None
386
+ | Iterable[protos.GenerateContentResponse]
387
+ | AsyncIterable[protos.GenerateContentResponse]
388
+ ),
389
+ result: protos.GenerateContentResponse,
390
+ chunks: Iterable[protos.GenerateContentResponse] | None = None,
391
+ ):
392
+ self._done = done
393
+ self._iterator = iterator
394
+ self._result = result
395
+ if chunks is None:
396
+ self._chunks = [result]
397
+ else:
398
+ self._chunks = list(chunks)
399
+ if result.prompt_feedback.block_reason:
400
+ self._error = BlockedPromptException(result)
401
+ else:
402
+ self._error = None
403
+
404
+ def to_dict(self):
405
+ """Returns the result as a JSON-compatible dict.
406
+
407
+ Note: This doesn't capture the iterator state when streaming, it only captures the accumulated
408
+ `GenerateContentResponse` fields.
409
+
410
+ >>> import json
411
+ >>> response = model.generate_content('Hello?')
412
+ >>> json.dumps(response.to_dict())
413
+ """
414
+ return type(self._result).to_dict(self._result)
415
+
416
+ @property
417
+ def candidates(self):
418
+ """The list of candidate responses.
419
+
420
+ Raises:
421
+ IncompleteIterationError: With `stream=True` if iteration over the stream was not completed.
422
+ """
423
+ if not self._done:
424
+ raise IncompleteIterationError(_INCOMPLETE_ITERATION_MESSAGE)
425
+ return self._result.candidates
426
+
427
+ @property
428
+ def parts(self):
429
+ """A quick accessor equivalent to `self.candidates[0].content.parts`
430
+
431
+ Raises:
432
+ ValueError: If the candidate list does not contain exactly one candidate.
433
+ """
434
+ candidates = self.candidates
435
+ if not candidates:
436
+ msg = (
437
+ "Invalid operation: The `response.parts` quick accessor requires a single candidate, "
438
+ "but but `response.candidates` is empty."
439
+ )
440
+ if self.prompt_feedback:
441
+ raise ValueError(
442
+ msg + "\nThis appears to be caused by a blocked prompt, "
443
+ f"see `response.prompt_feedback`: {self.prompt_feedback}"
444
+ )
445
+ else:
446
+ raise ValueError(msg)
447
+
448
+ if len(candidates) > 1:
449
+ raise ValueError(
450
+ "Invalid operation: The `response.parts` quick accessor retrieves the parts for a single candidate. "
451
+ "This response contains multiple candidates, please use `result.candidates[index].text`."
452
+ )
453
+ parts = candidates[0].content.parts
454
+ return parts
455
+
456
+ @property
457
+ def text(self):
458
+ """A quick accessor equivalent to `self.candidates[0].content.parts[0].text`
459
+
460
+ Raises:
461
+ ValueError: If the candidate list or parts list does not contain exactly one entry.
462
+ """
463
+ parts = self.parts
464
+ if not parts:
465
+ candidate = self.candidates[0]
466
+
467
+ fr = candidate.finish_reason
468
+ FinishReason = protos.Candidate.FinishReason
469
+
470
+ msg = (
471
+ "Invalid operation: The `response.text` quick accessor requires the response to contain a valid "
472
+ "`Part`, but none were returned. The candidate's "
473
+ f"[finish_reason](https://ai.google.dev/api/generate-content#finishreason) is {fr}."
474
+ )
475
+
476
+ if fr is FinishReason.FINISH_REASON_UNSPECIFIED:
477
+ raise ValueError(msg)
478
+ elif fr is FinishReason.STOP:
479
+ raise ValueError(msg)
480
+ elif fr is FinishReason.MAX_TOKENS:
481
+ raise ValueError(msg)
482
+ elif fr is FinishReason.SAFETY:
483
+ raise ValueError(
484
+ msg + f" The candidate's safety_ratings are: {candidate.safety_ratings}.",
485
+ candidate.safety_ratings,
486
+ )
487
+ elif fr is FinishReason.RECITATION:
488
+ raise ValueError(
489
+ msg + " Meaning that the model was reciting from copyrighted material."
490
+ )
491
+ elif fr is FinishReason.LANGUAGE:
492
+ raise ValueError(msg + " Meaning the response was using an unsupported language.")
493
+ elif fr is FinishReason.OTHER:
494
+ raise ValueError(msg)
495
+ elif fr is FinishReason.BLOCKLIST:
496
+ raise ValueError(msg)
497
+ elif fr is FinishReason.PROHIBITED_CONTENT:
498
+ raise ValueError(msg)
499
+ elif fr is FinishReason.SPII:
500
+ raise ValueError(msg + " SPII - Sensitive Personally Identifiable Information.")
501
+ elif fr is FinishReason.MALFORMED_FUNCTION_CALL:
502
+ raise ValueError(
503
+ msg + " Meaning that model generated a `FunctionCall` that was invalid. "
504
+ "Setting the "
505
+ "[Function calling mode](https://ai.google.dev/gemini-api/docs/function-calling#function_calling_mode) "
506
+ "to `ANY` can fix this because it enables constrained decoding."
507
+ )
508
+ else:
509
+ raise ValueError(msg)
510
+
511
+ texts = []
512
+ for part in parts:
513
+ if "text" in part:
514
+ texts.append(part.text)
515
+ continue
516
+ if "executable_code" in part:
517
+ language = part.executable_code.language.name.lower()
518
+ if language == "language_unspecified":
519
+ language = ""
520
+ else:
521
+ language = f" {language}"
522
+ texts.extend([f"```{language}", part.executable_code.code.lstrip("\n"), "```"])
523
+ continue
524
+ if "code_execution_result" in part:
525
+ outcome_result = part.code_execution_result.outcome.name.lower().replace(
526
+ "outcome_", ""
527
+ )
528
+ if outcome_result == "ok" or outcome_result == "unspecified":
529
+ outcome_result = ""
530
+ else:
531
+ outcome_result = f" {outcome_result}"
532
+ texts.extend([f"```{outcome_result}", part.code_execution_result.output, "```"])
533
+ continue
534
+
535
+ part_type = protos.Part.pb(part).whichOneof("data")
536
+ raise ValueError(f"Could not convert `part.{part_type}` to text.")
537
+
538
+ return "\n".join(texts)
539
+
540
+ @property
541
+ def prompt_feedback(self):
542
+ return self._result.prompt_feedback
543
+
544
+ @property
545
+ def usage_metadata(self):
546
+ return self._result.usage_metadata
547
+
548
+ @property
549
+ def model_version(self):
550
+ return self._result.model_version
551
+
552
+ def __str__(self) -> str:
553
+ if self._done:
554
+ _iterator = "None"
555
+ else:
556
+ _iterator = f"<{self._iterator.__class__.__name__}>"
557
+
558
+ as_dict = type(self._result).to_dict(
559
+ self._result, use_integers_for_enums=False, including_default_value_fields=False
560
+ )
561
+ json_str = json.dumps(as_dict, indent=2)
562
+
563
+ _result = f"protos.GenerateContentResponse({json_str})"
564
+ _result = _result.replace("\n", "\n ")
565
+
566
+ if self._error:
567
+
568
+ _error = f",\nerror={repr(self._error)}"
569
+ else:
570
+ _error = ""
571
+
572
+ return (
573
+ textwrap.dedent(
574
+ f"""\
575
+ response:
576
+ {type(self).__name__}(
577
+ done={self._done},
578
+ iterator={_iterator},
579
+ result={_result},
580
+ )"""
581
+ )
582
+ + _error
583
+ )
584
+
585
+ __repr__ = __str__
586
+
587
+
588
+ @contextlib.contextmanager
589
+ def rewrite_stream_error():
590
+ try:
591
+ yield
592
+ except (google.protobuf.json_format.ParseError, AttributeError) as e:
593
+ raise google.api_core.exceptions.BadRequest(
594
+ "Unknown error trying to retrieve streaming response. "
595
+ "Please retry with `stream=False` for more details."
596
+ )
597
+
598
+
599
+ GENERATE_CONTENT_RESPONSE_DOC = """Instances of this class manage the response of the `generate_content` method.
600
+
601
+ These are returned by `GenerativeModel.generate_content` and `ChatSession.send_message`.
602
+ This object is based on the low level `protos.GenerateContentResponse` class which just has `prompt_feedback`
603
+ and `candidates` attributes. This class adds several quick accessors for common use cases.
604
+
605
+ The same object type is returned for both `stream=True/False`.
606
+
607
+ ### Streaming
608
+
609
+ When you pass `stream=True` to `GenerativeModel.generate_content` or `ChatSession.send_message`,
610
+ iterate over this object to receive chunks of the response:
611
+
612
+ ```
613
+ response = model.generate_content(..., stream=True):
614
+ for chunk in response:
615
+ print(chunk.text)
616
+ ```
617
+
618
+ `GenerateContentResponse.prompt_feedback` is available immediately but
619
+ `GenerateContentResponse.candidates`, and all the attributes derived from them (`.text`, `.parts`),
620
+ are only available after the iteration is complete.
621
+ """
622
+
623
+ ASYNC_GENERATE_CONTENT_RESPONSE_DOC = (
624
+ """This is the async version of `genai.GenerateContentResponse`."""
625
+ )
626
+
627
+
628
+ @string_utils.set_doc(GENERATE_CONTENT_RESPONSE_DOC)
629
+ class GenerateContentResponse(BaseGenerateContentResponse):
630
+ @classmethod
631
+ def from_iterator(cls, iterator: Iterable[protos.GenerateContentResponse]):
632
+ iterator = iter(iterator)
633
+ with rewrite_stream_error():
634
+ response = next(iterator)
635
+
636
+ return cls(
637
+ done=False,
638
+ iterator=iterator,
639
+ result=response,
640
+ )
641
+
642
+ @classmethod
643
+ def from_response(cls, response: protos.GenerateContentResponse):
644
+ return cls(
645
+ done=True,
646
+ iterator=None,
647
+ result=response,
648
+ )
649
+
650
+ def __iter__(self):
651
+ # This is not thread safe.
652
+ if self._done:
653
+ for chunk in self._chunks:
654
+ yield GenerateContentResponse.from_response(chunk)
655
+ return
656
+
657
+ # Always have the next chunk available.
658
+ if len(self._chunks) == 0:
659
+ self._chunks.append(next(self._iterator))
660
+
661
+ for n in itertools.count():
662
+ if self._error:
663
+ raise self._error
664
+
665
+ if n >= len(self._chunks) - 1:
666
+ # Look ahead for a new item, so that you know the stream is done
667
+ # when you yield the last item.
668
+ if self._done:
669
+ return
670
+
671
+ try:
672
+ item = next(self._iterator)
673
+ except StopIteration:
674
+ self._done = True
675
+ except Exception as e:
676
+ self._error = e
677
+ self._done = True
678
+ else:
679
+ self._chunks.append(item)
680
+ self._result = _join_chunks([self._result, item])
681
+
682
+ item = self._chunks[n]
683
+
684
+ item = GenerateContentResponse.from_response(item)
685
+ yield item
686
+
687
+ def resolve(self):
688
+ if self._done:
689
+ return
690
+
691
+ for _ in self:
692
+ pass
693
+
694
+
695
+ @string_utils.set_doc(ASYNC_GENERATE_CONTENT_RESPONSE_DOC)
696
+ class AsyncGenerateContentResponse(BaseGenerateContentResponse):
697
+ @classmethod
698
+ async def from_aiterator(cls, iterator: AsyncIterable[protos.GenerateContentResponse]):
699
+ iterator = aiter(iterator) # type: ignore
700
+ with rewrite_stream_error():
701
+ response = await anext(iterator) # type: ignore
702
+
703
+ return cls(
704
+ done=False,
705
+ iterator=iterator,
706
+ result=response,
707
+ )
708
+
709
+ @classmethod
710
+ def from_response(cls, response: protos.GenerateContentResponse):
711
+ return cls(
712
+ done=True,
713
+ iterator=None,
714
+ result=response,
715
+ )
716
+
717
+ async def __aiter__(self):
718
+ # This is not thread safe.
719
+ if self._done:
720
+ for chunk in self._chunks:
721
+ yield GenerateContentResponse.from_response(chunk)
722
+ return
723
+
724
+ # Always have the next chunk available.
725
+ if len(self._chunks) == 0:
726
+ self._chunks.append(await anext(self._iterator)) # type: ignore
727
+
728
+ for n in itertools.count():
729
+ if self._error:
730
+ raise self._error
731
+
732
+ if n >= len(self._chunks) - 1:
733
+ # Look ahead for a new item, so that you know the stream is done
734
+ # when you yield the last item.
735
+ if self._done:
736
+ return
737
+
738
+ try:
739
+ item = await anext(self._iterator) # type: ignore
740
+ except StopAsyncIteration:
741
+ self._done = True
742
+ except Exception as e:
743
+ self._error = e
744
+ self._done = True
745
+ else:
746
+ self._chunks.append(item)
747
+ self._result = _join_chunks([self._result, item])
748
+
749
+ item = self._chunks[n]
750
+
751
+ item = GenerateContentResponse.from_response(item)
752
+ yield item
753
+
754
+ async def resolve(self):
755
+ if self._done:
756
+ return
757
+
758
+ async for _ in self:
759
+ pass
.venv/lib/python3.11/site-packages/google/generativeai/types/image_types/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (285 Bytes). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/types/image_types/__pycache__/_image_types.cpython-311.pyc ADDED
Binary file (20.5 kB). View file
 
.venv/lib/python3.11/site-packages/google/generativeai/types/image_types/_image_types.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import io
5
+ import json
6
+ import mimetypes
7
+ import os
8
+ import pathlib
9
+ import typing
10
+ from typing import Any, Dict, Optional, Union
11
+
12
+ import httplib2
13
+ import threading
14
+
15
+ import googleapiclient.http
16
+
17
+ from google.generativeai import protos
18
+
19
+ # pylint: disable=g-import-not-at-top
20
+ if typing.TYPE_CHECKING:
21
+ import PIL.Image
22
+ import PIL.ImageFile
23
+ import IPython.display
24
+
25
+ _IMPORTED_IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
26
+ ImageType = Union["Image", PIL.Image.Image, IPython.display.Image]
27
+ else:
28
+ _IMPORTED_IMAGE_TYPES = ()
29
+ try:
30
+ import PIL.Image
31
+ import PIL.ImageFile
32
+
33
+ _IMPORTED_IMAGE_TYPES = _IMPORTED_IMAGE_TYPES + (PIL.Image.Image,)
34
+ except ImportError:
35
+ PIL = None
36
+
37
+ try:
38
+ import IPython.display
39
+
40
+ IMAGE_TYPES = _IMPORTED_IMAGE_TYPES + (IPython.display.Image,)
41
+ except ImportError:
42
+ IPython = None
43
+
44
+ ImageType = Union["Image", "PIL.Image.Image", "IPython.display.Image"]
45
+ # pylint: enable=g-import-not-at-top
46
+
47
+ __all__ = [
48
+ "Image",
49
+ "GeneratedImage",
50
+ "to_image",
51
+ "check_watermark",
52
+ "CheckWatermarkResult",
53
+ "ImageType",
54
+ "Video",
55
+ "GeneratedVideo",
56
+ ]
57
+
58
+
59
+ def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
60
+ # If the image is a local file, return a file-based blob without any modification.
61
+ # Otherwise, return a lossless WebP blob (same quality with optimized size).
62
+ def file_blob(image: PIL.Image.Image) -> Union[protos.Blob, None]:
63
+ if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None:
64
+ return None
65
+ filename = str(image.filename)
66
+ if not pathlib.Path(filename).is_file():
67
+ return None
68
+
69
+ mime_type = image.get_format_mimetype()
70
+ image_bytes = pathlib.Path(filename).read_bytes()
71
+
72
+ return protos.Blob(mime_type=mime_type, data=image_bytes)
73
+
74
+ def webp_blob(image: PIL.Image.Image) -> protos.Blob:
75
+ # Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
76
+ image_io = io.BytesIO()
77
+ image.save(image_io, format="webp", lossless=True)
78
+ image_io.seek(0)
79
+
80
+ mime_type = "image/webp"
81
+ image_bytes = image_io.read()
82
+
83
+ return protos.Blob(mime_type=mime_type, data=image_bytes)
84
+
85
+ return file_blob(image) or webp_blob(image)
86
+
87
+
88
+ def image_to_blob(image: ImageType) -> protos.Blob:
89
+ if PIL is not None:
90
+ if isinstance(image, PIL.Image.Image):
91
+ return _pil_to_blob(image)
92
+
93
+ if IPython is not None:
94
+ if isinstance(image, IPython.display.Image):
95
+ name = image.filename
96
+ if name is None:
97
+ raise ValueError(
98
+ "Conversion failed. The `IPython.display.Image` can only be converted if "
99
+ "it is constructed from a local file. Please ensure you are using the format: Image(filename='...')."
100
+ )
101
+ mime_type, _ = mimetypes.guess_type(name)
102
+ if mime_type is None:
103
+ mime_type = "image/unknown"
104
+
105
+ return protos.Blob(mime_type=mime_type, data=image.data)
106
+
107
+ if isinstance(image, Image):
108
+ return protos.Blob(mime_type=image._mime_type, data=image._image_bytes)
109
+
110
+ raise TypeError(
111
+ "Image conversion failed. The input was expected to be of type `Image` "
112
+ "(either `PIL.Image.Image` or `IPython.display.Image`).\n"
113
+ f"However, received an object of type: {type(image)}.\n"
114
+ f"Object Value: {image}"
115
+ )
116
+
117
+
118
+ class CheckWatermarkResult:
119
+ def __init__(self, predictions):
120
+ self._predictions = predictions
121
+
122
+ @property
123
+ def decision(self):
124
+ return self._predictions[0]["decision"]
125
+
126
+ def __str__(self):
127
+ return f"CheckWatermarkResult([{{'decision': {self.decision!r}}}])"
128
+
129
+ def __bool__(self):
130
+ decision = self.decision
131
+ if decision == "ACCEPT":
132
+ return True
133
+ elif decision == "REJECT":
134
+ return False
135
+ else:
136
+ raise ValueError("Unrecognized result")
137
+
138
+
139
+ def to_image(img: Union[pathlib.Path, ImageType]) -> Image:
140
+ if isinstance(img, Image):
141
+ pass
142
+ elif isinstance(img, pathlib.Path):
143
+ img = Image.load_from_file(img)
144
+ elif IPython.display is not None and isinstance(img, IPython.display.Image):
145
+ img = Image(image_bytes=img.data)
146
+ elif PIL.Image is not None and isinstance(img, PIL.Image.Image):
147
+ blob = _pil_to_blob(img)
148
+ img = Image(image_bytes=blob.data)
149
+ elif isinstance(img, protos.Blob):
150
+ img = Image(image_bytes=img.data)
151
+ else:
152
+ raise TypeError(
153
+ f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}"
154
+ )
155
+
156
+ return img
157
+
158
+
159
+ class Image:
160
+ """Image."""
161
+
162
+ __module__ = "vertexai.vision_models"
163
+
164
+ _loaded_bytes: Optional[bytes] = None
165
+ _loaded_image: Optional["PIL.Image.Image"] = None
166
+
167
+ def __init__(
168
+ self,
169
+ image_bytes: Optional[bytes],
170
+ ):
171
+ """Creates an `Image` object.
172
+
173
+ Args:
174
+ image_bytes: Image file bytes. Image can be in PNG or JPEG format.
175
+ """
176
+ self._image_bytes = image_bytes
177
+
178
+ @staticmethod
179
+ def load_from_file(location: os.PathLike) -> "Image":
180
+ """Loads image from local file.
181
+
182
+ Args:
183
+ location: Local path from where to load
184
+ the image.
185
+
186
+ Returns:
187
+ Loaded image as an `Image` object.
188
+ """
189
+ # Load image from local path
190
+ image_bytes = pathlib.Path(location).read_bytes()
191
+ image = Image(image_bytes=image_bytes)
192
+ return image
193
+
194
+ @property
195
+ def _image_bytes(self) -> bytes:
196
+ return self._loaded_bytes
197
+
198
+ @_image_bytes.setter
199
+ def _image_bytes(self, value: bytes):
200
+ self._loaded_bytes = value
201
+
202
+ @property
203
+ def _pil_image(self) -> "PIL.Image.Image": # type: ignore
204
+ if self._loaded_image is None:
205
+ if not PIL:
206
+ raise RuntimeError(
207
+ "The PIL module is not available. Please install the Pillow package."
208
+ )
209
+ self._loaded_image = PIL.Image.open(io.BytesIO(self._image_bytes))
210
+ return self._loaded_image
211
+
212
+ @property
213
+ def _size(self):
214
+ return self._pil_image.size
215
+
216
+ @property
217
+ def _mime_type(self) -> str:
218
+ """Returns the MIME type of the image."""
219
+ import PIL
220
+
221
+ return PIL.Image.MIME.get(self._pil_image.format, "image/jpeg")
222
+
223
+ def show(self):
224
+ """Shows the image.
225
+
226
+ This method only works when in a notebook environment.
227
+ """
228
+ if PIL and IPython:
229
+ IPython.display.display(self._pil_image)
230
+
231
+ def save(self, location: str):
232
+ """Saves image to a file.
233
+
234
+ Args:
235
+ location: Local path where to save the image.
236
+ """
237
+ pathlib.Path(location).write_bytes(self._image_bytes)
238
+
239
+ def _as_base64_string(self) -> str:
240
+ """Encodes image using the base64 encoding.
241
+
242
+ Returns:
243
+ Base64 encoding of the image as a string.
244
+ """
245
+ # ! b64encode returns `bytes` object, not `str`.
246
+ # We need to convert `bytes` to `str`, otherwise we get service error:
247
+ # "received initial metadata size exceeds limit"
248
+ return base64.b64encode(self._image_bytes).decode("ascii")
249
+
250
+ def _repr_png_(self):
251
+ return self._pil_image._repr_png_() # type:ignore
252
+
253
+
254
+ IMAGE_TYPES = _IMPORTED_IMAGE_TYPES + (Image,)
255
+
256
+
257
+ _EXIF_USER_COMMENT_TAG_IDX = 0x9286
258
+ _IMAGE_GENERATION_PARAMETERS_EXIF_KEY = (
259
+ "google.cloud.vertexai.image_generation.image_generation_parameters"
260
+ )
261
+
262
+
263
+ class GeneratedImage(Image):
264
+ """Generated image."""
265
+
266
+ __module__ = "google.generativeai"
267
+
268
+ def __init__(
269
+ self,
270
+ image_bytes: Optional[bytes],
271
+ generation_parameters: Dict[str, Any],
272
+ ):
273
+ """Creates a `GeneratedImage` object.
274
+
275
+ Args:
276
+ image_bytes: Image file bytes. Image can be in PNG or JPEG format.
277
+ generation_parameters: Image generation parameter values.
278
+ """
279
+ super().__init__(image_bytes=image_bytes)
280
+ self._generation_parameters = generation_parameters
281
+
282
+ @property
283
+ def generation_parameters(self):
284
+ """Image generation parameters as a dictionary."""
285
+ return self._generation_parameters
286
+
287
+ @staticmethod
288
+ def load_from_file(location: os.PathLike) -> "GeneratedImage":
289
+ """Loads image from file.
290
+
291
+ Args:
292
+ location: Local path from where to load the image.
293
+
294
+ Returns:
295
+ Loaded image as a `GeneratedImage` object.
296
+ """
297
+ base_image = Image.load_from_file(location=location)
298
+ exif = base_image._pil_image.getexif() # pylint: disable=protected-access
299
+ exif_comment_dict = json.loads(exif[_EXIF_USER_COMMENT_TAG_IDX])
300
+ generation_parameters = exif_comment_dict[_IMAGE_GENERATION_PARAMETERS_EXIF_KEY]
301
+ return GeneratedImage(
302
+ image_bytes=base_image._image_bytes, # pylint: disable=protected-access
303
+ generation_parameters=generation_parameters,
304
+ )
305
+
306
+ def save(self, location: str, include_generation_parameters: bool = True):
307
+ """Saves image to a file.
308
+
309
+ Args:
310
+ location: Local path where to save the image.
311
+ include_generation_parameters: Whether to include the image
312
+ generation parameters in the image's EXIF metadata.
313
+ """
314
+ if include_generation_parameters:
315
+ if not self._generation_parameters:
316
+ raise ValueError("Image does not have generation parameters.")
317
+ if not PIL:
318
+ raise ValueError("The PIL module is required for saving generation parameters.")
319
+
320
+ exif = self._pil_image.getexif()
321
+ exif[_EXIF_USER_COMMENT_TAG_IDX] = json.dumps(
322
+ {_IMAGE_GENERATION_PARAMETERS_EXIF_KEY: self._generation_parameters}
323
+ )
324
+ self._pil_image.save(location, exif=exif)
325
+ else:
326
+ super().save(location=location)
327
+
328
+
329
+ class Video:
330
+ """Video."""
331
+
332
+ _loaded_bytes: Optional[bytes] = None
333
+
334
+ def __init__(
335
+ self,
336
+ *,
337
+ video_bytes: Optional[bytes] = None,
338
+ mime_type: Optional[str] = None,
339
+ ):
340
+ """Creates a `Video` object.
341
+ Args:
342
+ video_bytes: Video file bytes. Video can be in AVI, FLV, MKV, MOV,
343
+ MP4, MPEG, MPG, WEBM, and WMV formats.
344
+ """
345
+ self._video_bytes = video_bytes
346
+ self._mime_type = mime_type
347
+
348
+ def _ipython_display_(self):
349
+ if IPython.display is None:
350
+ return
351
+
352
+ IPython.display.display(
353
+ IPython.display.Video(data=self._video_bytes, mimetype=self._mime_type, embed=True)
354
+ )
355
+
356
+ @staticmethod
357
+ def load_from_file(location: str) -> "Video":
358
+ """Loads video from local file.
359
+ Args:
360
+ location: Local path from where to load the video.
361
+ Returns:
362
+ Loaded video as an `Video` object.
363
+ """
364
+ # Load video from local path
365
+ video_bytes = pathlib.Path(location).read_bytes()
366
+ mimetypes.guess_type(location)
367
+ video = Video(video_bytes=video_bytes, mime_type=mimetypes.guess_type(location))
368
+ return video
369
+
370
+ @property
371
+ def _video_bytes(self) -> bytes:
372
+ return self._loaded_bytes
373
+
374
+ @_video_bytes.setter
375
+ def _video_bytes(self, value: bytes):
376
+ self._loaded_bytes = value
377
+
378
+ @property
379
+ def mime_type(self) -> str:
380
+ """Returns the MIME type of the video."""
381
+ return self._mime_type
382
+
383
+ def save(self, location: str):
384
+ """Saves video to a file.
385
+ Args:
386
+ location: Local path where to save the video.
387
+ """
388
+ pathlib.Path(location).write_bytes(self._video_bytes)
389
+
390
+ def _as_base64_string(self) -> str:
391
+ """Encodes video using the base64 encoding.
392
+ Returns:
393
+ Base64 encoding of the video as a string.
394
+ """
395
+ # ! b64encode returns `bytes` object, not `str`.
396
+ # We need to convert `bytes` to `str`, otherwise we get service error:
397
+ # "received initial metadata size exceeds limit"
398
+ return base64.b64encode(self._video_bytes).decode("ascii")
399
+
400
+
401
+ _thread_local = threading.local()
402
+ _thread_local.http = httplib2.Http()
403
+
404
+
405
+ class GeneratedVideo(Video):
406
+ def __init__(self, *, uri):
407
+ self._uri = uri
408
+ self._loaded_bytes = None
409
+
410
+ @property
411
+ def _mime_type(self):
412
+ return "video/mp4"
413
+
414
+ @property
415
+ def _video_bytes(self) -> bytes:
416
+ if self._loaded_bytes is None:
417
+ self._loaded_bytes = self.download()
418
+ return self._loaded_bytes
419
+
420
+ @_video_bytes.setter
421
+ def _video_bytes(self, value: bytes):
422
+ self._loaded_bytes = value
423
+
424
+ def download(self):
425
+ api_key = client._client_manager.client_config["client_options"].api_key
426
+
427
+ request = googleapiclient.http.HttpRequest(
428
+ http=_thread_local.http,
429
+ postproc=googleapiclient.model.MediaModel().response,
430
+ uri=f"{self._uri}&key={api_key}",
431
+ method="GET",
432
+ headers={
433
+ "accept": "*/*",
434
+ "accept-encoding": "gzip, deflate",
435
+ "user-agent": "(gzip)",
436
+ "x-goog-api-client": "gdcl/2.151.0 gl-python/3.12.7",
437
+ },
438
+ )
439
+ result = request.execute()
440
+ return result
.venv/lib/python3.11/site-packages/google/generativeai/types/model_types.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Type definitions for the models service."""
16
+ from __future__ import annotations
17
+
18
+ from collections.abc import Mapping
19
+ import csv
20
+ import dataclasses
21
+ import datetime
22
+ import json
23
+ import pathlib
24
+ import re
25
+
26
+ from typing import Any, Iterable, Union
27
+
28
+ import urllib.request
29
+ from typing_extensions import TypedDict
30
+
31
+ from google.generativeai import protos
32
+ from google.generativeai.types import permission_types
33
+ from google.generativeai import string_utils
34
+
35
+
36
+ __all__ = [
37
+ "Model",
38
+ "ModelNameOptions",
39
+ "AnyModelNameOptions",
40
+ "BaseModelNameOptions",
41
+ "TunedModelNameOptions",
42
+ "ModelsIterable",
43
+ "TunedModel",
44
+ "TunedModelState",
45
+ ]
46
+
47
+ TunedModelState = protos.TunedModel.State
48
+
49
+ TunedModelStateOptions = Union[None, str, int, TunedModelState]
50
+
51
+ _TUNED_MODEL_VALID_NAME = r"[a-z](([a-z0-9-]{0,61}[a-z0-9])?)$"
52
+ TUNED_MODEL_NAME_ERROR_MSG = """The `name` must consist of alphanumeric characters (or -) and be at most 63 characters; The name you entered:
53
+ \tlen(name)== {length}
54
+ \tname={name}
55
+ """
56
+
57
+
58
+ def valid_tuned_model_name(name: str) -> bool:
59
+ return re.match(_TUNED_MODEL_VALID_NAME, name) is not None
60
+
61
+
62
+ # fmt: off
63
+ _TUNED_MODEL_STATES: dict[TunedModelStateOptions, TunedModelState] = {
64
+ TunedModelState.ACTIVE: TunedModelState.ACTIVE,
65
+ int(TunedModelState.ACTIVE): TunedModelState.ACTIVE,
66
+ "active": TunedModelState.ACTIVE,
67
+
68
+ TunedModelState.CREATING: TunedModelState.CREATING,
69
+ int(TunedModelState.CREATING): TunedModelState.CREATING,
70
+ "creating": TunedModelState.CREATING,
71
+
72
+ TunedModelState.FAILED: TunedModelState.FAILED,
73
+ int(TunedModelState.FAILED): TunedModelState.FAILED,
74
+ "failed": TunedModelState.FAILED,
75
+
76
+ TunedModelState.STATE_UNSPECIFIED: TunedModelState.STATE_UNSPECIFIED,
77
+ int(TunedModelState.STATE_UNSPECIFIED): TunedModelState.STATE_UNSPECIFIED,
78
+ "state_unspecified": TunedModelState.STATE_UNSPECIFIED,
79
+ "unspecified": TunedModelState.STATE_UNSPECIFIED,
80
+ None: TunedModelState.STATE_UNSPECIFIED,
81
+ }
82
+ # fmt: on
83
+
84
+
85
+ def to_tuned_model_state(x: TunedModelStateOptions) -> TunedModelState:
86
+ if isinstance(x, str):
87
+ x = x.lower()
88
+ return _TUNED_MODEL_STATES[x]
89
+
90
+
91
+ @string_utils.prettyprint
92
+ @dataclasses.dataclass
93
+ class Model:
94
+ """A dataclass representation of a `protos.Model`.
95
+
96
+ Attributes:
97
+ name: The resource name of the `Model`. Format: `models/{model}` with a `{model}` naming
98
+ convention of: "{base_model_id}-{version}". For example: `models/chat-bison-001`.
99
+ base_model_id: The base name of the model. For example: `chat-bison`.
100
+ version: The major version number of the model. For example: `001`.
101
+ display_name: The human-readable name of the model. E.g. `"Chat Bison"`. The name can be up
102
+ to 128 characters long and can consist of any UTF-8 characters.
103
+ description: A short description of the model.
104
+ input_token_limit: Maximum number of input tokens allowed for this model.
105
+ output_token_limit: Maximum number of output tokens available for this model.
106
+ supported_generation_methods: lists which methods are supported by the model. The method
107
+ names are defined as Pascal case strings, such as `generateMessage` which correspond to
108
+ API methods.
109
+ """
110
+
111
+ name: str
112
+ base_model_id: str
113
+ version: str
114
+ display_name: str
115
+ description: str
116
+ input_token_limit: int
117
+ output_token_limit: int
118
+ supported_generation_methods: list[str]
119
+ temperature: float | None = None
120
+ max_temperature: float | None = None
121
+ top_p: float | None = None
122
+ top_k: int | None = None
123
+
124
+
125
+ def _fix_microseconds(match):
126
+ # microseconds needs exactly 6 digits
127
+ fraction = float(match.group(0))
128
+ return f".{int(round(fraction*1e6)):06d}"
129
+
130
+
131
+ def idecode_time(parent: dict["str", Any], name: str):
132
+ time = parent.pop(name, None)
133
+ if time is not None:
134
+ if "." in time:
135
+ time = re.sub(r"\.\d+", _fix_microseconds, time)
136
+ dt = datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S.%fZ")
137
+ else:
138
+ dt = datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%SZ")
139
+
140
+ dt = dt.replace(tzinfo=datetime.timezone.utc)
141
+ parent[name] = dt
142
+
143
+
144
+ def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel:
145
+ if isinstance(tuned_model, protos.TunedModel):
146
+ tuned_model = type(tuned_model).to_dict(
147
+ tuned_model, including_default_value_fields=False
148
+ ) # pytype: disable=attribute-error
149
+ tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None))
150
+
151
+ base_model = tuned_model.pop("base_model", None)
152
+ tuned_model_source = tuned_model.pop("tuned_model_source", None)
153
+ if base_model is not None:
154
+ tuned_model["base_model"] = base_model
155
+ tuned_model["source_model"] = base_model
156
+ elif tuned_model_source is not None:
157
+ tuned_model["base_model"] = tuned_model_source["base_model"]
158
+ tuned_model["source_model"] = tuned_model_source["tuned_model"]
159
+
160
+ idecode_time(tuned_model, "create_time")
161
+ idecode_time(tuned_model, "update_time")
162
+
163
+ task = tuned_model.pop("tuning_task", None)
164
+ if task is not None:
165
+ hype = task.pop("hyperparameters", None)
166
+ if hype is not None:
167
+ hype = Hyperparameters(**hype)
168
+ task["hyperparameters"] = hype
169
+
170
+ idecode_time(task, "start_time")
171
+ idecode_time(task, "complete_time")
172
+
173
+ snapshots = task.pop("snapshots", None)
174
+ if snapshots is not None:
175
+ for snap in snapshots:
176
+ idecode_time(snap, "compute_time")
177
+ task["snapshots"] = snapshots
178
+ task = TuningTask(**task)
179
+ tuned_model["tuning_task"] = task
180
+ return TunedModel(**tuned_model)
181
+
182
+
183
+ @string_utils.prettyprint
184
+ @dataclasses.dataclass
185
+ class TunedModel:
186
+ """A dataclass representation of a `protos.TunedModel`."""
187
+
188
+ name: str | None = None
189
+ source_model: str | None = None
190
+ base_model: str | None = None
191
+ display_name: str = ""
192
+ description: str = ""
193
+ temperature: float | None = None
194
+ top_p: float | None = None
195
+ top_k: float | None = None
196
+ state: TunedModelState = TunedModelState.STATE_UNSPECIFIED
197
+ create_time: datetime.datetime | None = None
198
+ update_time: datetime.datetime | None = None
199
+ tuning_task: TuningTask | None = None
200
+ reader_project_numbers: list[int] | None = None
201
+
202
+ @property
203
+ def permissions(self) -> permission_types.Permissions:
204
+ return permission_types.Permissions(self)
205
+
206
+
207
+ @string_utils.prettyprint
208
+ @dataclasses.dataclass
209
+ class TuningTask:
210
+ start_time: datetime.datetime | None = None
211
+ complete_time: datetime.datetime | None = None
212
+ snapshots: list[TuningSnapshot] = dataclasses.field(default_factory=list)
213
+ hyperparameters: Hyperparameters | None = None
214
+
215
+
216
+ class TuningExampleDict(TypedDict):
217
+ text_input: str
218
+ output: str
219
+
220
+
221
+ TuningExampleOptions = Union[TuningExampleDict, protos.TuningExample, tuple[str, str], list[str]]
222
+
223
+ # TODO(markdaoust): gs:// URLS? File-type argument for files without extension?
224
+ TuningDataOptions = Union[
225
+ pathlib.Path,
226
+ str,
227
+ protos.Dataset,
228
+ Mapping[str, Iterable[str]],
229
+ Iterable[TuningExampleOptions],
230
+ ]
231
+
232
+
233
+ def encode_tuning_data(
234
+ data: TuningDataOptions, input_key="text_input", output_key="output"
235
+ ) -> protos.Dataset:
236
+ if isinstance(data, protos.Dataset):
237
+ return data
238
+
239
+ if isinstance(data, str):
240
+ # Strings are either URLs or system paths.
241
+ if re.match(r"^\w+://\S+$", data):
242
+ data = _normalize_url(data)
243
+ else:
244
+ # Normalize system paths to use pathlib
245
+ data = pathlib.Path(data)
246
+
247
+ if isinstance(data, (str, pathlib.Path)):
248
+ if isinstance(data, str):
249
+ f = urllib.request.urlopen(data)
250
+ # csv needs strings, json does not.
251
+ content = (line.decode("utf-8") for line in f)
252
+ else:
253
+ f = data.open("r")
254
+ content = f
255
+
256
+ if str(data).lower().endswith(".json"):
257
+ with f:
258
+ data = json.load(f)
259
+ else:
260
+ with f:
261
+ data = csv.DictReader(content)
262
+ return _convert_iterable(data, input_key, output_key)
263
+
264
+ if hasattr(data, "keys"):
265
+ return _convert_dict(data, input_key, output_key)
266
+ else:
267
+ return _convert_iterable(data, input_key, output_key)
268
+
269
+
270
+ def _normalize_url(url: str) -> str:
271
+ sheet_base = "https://docs.google.com/spreadsheets"
272
+ if url.startswith(sheet_base):
273
+ # Normalize google-sheets URLs to download the csv.
274
+ id_match = re.match(f"{sheet_base}/d/[^/]+", url)
275
+ if id_match is None:
276
+ raise ValueError("Incomplete Google Sheets URL: {data}")
277
+
278
+ if tab_match := re.search(r"gid=(\d+)", url):
279
+ tab_param = f"&gid={tab_match.group(1)}"
280
+ else:
281
+ tab_param = ""
282
+
283
+ url = f"{id_match.group(0)}/export?format=csv{tab_param}"
284
+
285
+ return url
286
+
287
+
288
+ def _convert_dict(data, input_key, output_key):
289
+ new_data = list()
290
+
291
+ try:
292
+ inputs = data[input_key]
293
+ except KeyError:
294
+ raise KeyError(
295
+ f"Invalid key: The input key '{input_key}' does not exist in the data. "
296
+ f"Available keys are: {sorted(data.keys())}."
297
+ )
298
+
299
+ try:
300
+ outputs = data[output_key]
301
+ except KeyError:
302
+ raise KeyError(
303
+ f"Invalid key: The output key '{output_key}' does not exist in the data. "
304
+ f"Available keys are: {sorted(data.keys())}."
305
+ )
306
+
307
+ for i, o in zip(inputs, outputs):
308
+ new_data.append(protos.TuningExample({"text_input": str(i), "output": str(o)}))
309
+ return protos.Dataset(examples=protos.TuningExamples(examples=new_data))
310
+
311
+
312
+ def _convert_iterable(data, input_key, output_key):
313
+ new_data = list()
314
+ for example in data:
315
+ example = encode_tuning_example(example, input_key, output_key)
316
+ new_data.append(example)
317
+ return protos.Dataset(examples=protos.TuningExamples(examples=new_data))
318
+
319
+
320
+ def encode_tuning_example(example: TuningExampleOptions, input_key, output_key):
321
+ if isinstance(example, protos.TuningExample):
322
+ return example
323
+ elif isinstance(example, (tuple, list)):
324
+ a, b = example
325
+ example = protos.TuningExample(text_input=a, output=b)
326
+ else: # dict
327
+ example = protos.TuningExample(text_input=example[input_key], output=example[output_key])
328
+ return example
329
+
330
+
331
+ @string_utils.prettyprint
332
+ @dataclasses.dataclass
333
+ class TuningSnapshot:
334
+ step: int
335
+ epoch: int
336
+ mean_score: float
337
+ compute_time: datetime.datetime
338
+
339
+
340
+ @string_utils.prettyprint
341
+ @dataclasses.dataclass
342
+ class Hyperparameters:
343
+ epoch_count: int = 0
344
+ batch_size: int = 0
345
+ learning_rate: float = 0.0
346
+
347
+
348
+ BaseModelNameOptions = Union[str, Model, protos.Model]
349
+ TunedModelNameOptions = Union[str, TunedModel, protos.TunedModel]
350
+ AnyModelNameOptions = Union[str, Model, protos.Model, TunedModel, protos.TunedModel]
351
+ ModelNameOptions = AnyModelNameOptions
352
+
353
+
354
+ def make_model_name(name: AnyModelNameOptions):
355
+ if isinstance(name, (Model, protos.Model, TunedModel, protos.TunedModel)):
356
+ name = name.name # pytype: disable=attribute-error
357
+ elif isinstance(name, str):
358
+ name = name
359
+ else:
360
+ raise TypeError(
361
+ "Invalid input type. Expected one of the following types: `str`, `Model`, or `TunedModel`."
362
+ )
363
+
364
+ if not (name.startswith("models/") or name.startswith("tunedModels/")):
365
+ raise ValueError(
366
+ f"Invalid model name: '{name}'. Model names should start with 'models/' or 'tunedModels/'."
367
+ )
368
+
369
+ return name
370
+
371
+
372
+ ModelsIterable = Iterable[Model]
373
+ TunedModelsIterable = Iterable[TunedModel]
374
+
375
+
376
+ @string_utils.prettyprint
377
+ @dataclasses.dataclass
378
+ class TokenCount:
379
+ """A dataclass representation of a `protos.TokenCountResponse`.
380
+
381
+ Attributes:
382
+ token_count: The number of tokens returned by the model's tokenizer for the `input_text`.
383
+ token_count_limit:
384
+ """
385
+
386
+ token_count: int
387
+ token_count_limit: int
388
+
389
+ def over_limit(self):
390
+ return self.token_count > self.token_count_limit
.venv/lib/python3.11/site-packages/google/generativeai/types/palm_safety_types.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ from collections.abc import Mapping
18
+
19
+ import enum
20
+ import typing
21
+ from typing import Dict, Iterable, List, Union
22
+
23
+ from typing_extensions import TypedDict
24
+
25
+
26
+ from google.generativeai import protos
27
+ from google.generativeai import string_utils
28
+
29
+
30
+ __all__ = [
31
+ "HarmCategory",
32
+ "HarmProbability",
33
+ "HarmBlockThreshold",
34
+ "BlockedReason",
35
+ "ContentFilterDict",
36
+ "SafetyRatingDict",
37
+ "SafetySettingDict",
38
+ "SafetyFeedbackDict",
39
+ ]
40
+
41
+ # These are basic python enums, it's okay to expose them
42
+ HarmProbability = protos.SafetyRating.HarmProbability
43
+ HarmBlockThreshold = protos.SafetySetting.HarmBlockThreshold
44
+ BlockedReason = protos.ContentFilter.BlockedReason
45
+
46
+
47
+ class HarmCategory:
48
+ """
49
+ Harm Categories supported by the palm-family models
50
+ """
51
+
52
+ HARM_CATEGORY_UNSPECIFIED = protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value
53
+ HARM_CATEGORY_DEROGATORY = protos.HarmCategory.HARM_CATEGORY_DEROGATORY.value
54
+ HARM_CATEGORY_TOXICITY = protos.HarmCategory.HARM_CATEGORY_TOXICITY.value
55
+ HARM_CATEGORY_VIOLENCE = protos.HarmCategory.HARM_CATEGORY_VIOLENCE.value
56
+ HARM_CATEGORY_SEXUAL = protos.HarmCategory.HARM_CATEGORY_SEXUAL.value
57
+ HARM_CATEGORY_MEDICAL = protos.HarmCategory.HARM_CATEGORY_MEDICAL.value
58
+ HARM_CATEGORY_DANGEROUS = protos.HarmCategory.HARM_CATEGORY_DANGEROUS.value
59
+
60
+
61
+ HarmCategoryOptions = Union[str, int, HarmCategory]
62
+
63
+ # fmt: off
64
+ _HARM_CATEGORIES: Dict[HarmCategoryOptions, protos.HarmCategory] = {
65
+ protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
66
+ HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
67
+ 0: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
68
+ "harm_category_unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
69
+ "unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
70
+
71
+ protos.HarmCategory.HARM_CATEGORY_DEROGATORY: protos.HarmCategory.HARM_CATEGORY_DEROGATORY,
72
+ HarmCategory.HARM_CATEGORY_DEROGATORY: protos.HarmCategory.HARM_CATEGORY_DEROGATORY,
73
+ 1: protos.HarmCategory.HARM_CATEGORY_DEROGATORY,
74
+ "harm_category_derogatory": protos.HarmCategory.HARM_CATEGORY_DEROGATORY,
75
+ "derogatory": protos.HarmCategory.HARM_CATEGORY_DEROGATORY,
76
+
77
+ protos.HarmCategory.HARM_CATEGORY_TOXICITY: protos.HarmCategory.HARM_CATEGORY_TOXICITY,
78
+ HarmCategory.HARM_CATEGORY_TOXICITY: protos.HarmCategory.HARM_CATEGORY_TOXICITY,
79
+ 2: protos.HarmCategory.HARM_CATEGORY_TOXICITY,
80
+ "harm_category_toxicity": protos.HarmCategory.HARM_CATEGORY_TOXICITY,
81
+ "toxicity": protos.HarmCategory.HARM_CATEGORY_TOXICITY,
82
+ "toxic": protos.HarmCategory.HARM_CATEGORY_TOXICITY,
83
+
84
+ protos.HarmCategory.HARM_CATEGORY_VIOLENCE: protos.HarmCategory.HARM_CATEGORY_VIOLENCE,
85
+ HarmCategory.HARM_CATEGORY_VIOLENCE: protos.HarmCategory.HARM_CATEGORY_VIOLENCE,
86
+ 3: protos.HarmCategory.HARM_CATEGORY_VIOLENCE,
87
+ "harm_category_violence": protos.HarmCategory.HARM_CATEGORY_VIOLENCE,
88
+ "violence": protos.HarmCategory.HARM_CATEGORY_VIOLENCE,
89
+ "violent": protos.HarmCategory.HARM_CATEGORY_VIOLENCE,
90
+
91
+ protos.HarmCategory.HARM_CATEGORY_SEXUAL: protos.HarmCategory.HARM_CATEGORY_SEXUAL,
92
+ HarmCategory.HARM_CATEGORY_SEXUAL: protos.HarmCategory.HARM_CATEGORY_SEXUAL,
93
+ 4: protos.HarmCategory.HARM_CATEGORY_SEXUAL,
94
+ "harm_category_sexual": protos.HarmCategory.HARM_CATEGORY_SEXUAL,
95
+ "sexual": protos.HarmCategory.HARM_CATEGORY_SEXUAL,
96
+ "sex": protos.HarmCategory.HARM_CATEGORY_SEXUAL,
97
+
98
+ protos.HarmCategory.HARM_CATEGORY_MEDICAL: protos.HarmCategory.HARM_CATEGORY_MEDICAL,
99
+ HarmCategory.HARM_CATEGORY_MEDICAL: protos.HarmCategory.HARM_CATEGORY_MEDICAL,
100
+ 5: protos.HarmCategory.HARM_CATEGORY_MEDICAL,
101
+ "harm_category_medical": protos.HarmCategory.HARM_CATEGORY_MEDICAL,
102
+ "medical": protos.HarmCategory.HARM_CATEGORY_MEDICAL,
103
+ "med": protos.HarmCategory.HARM_CATEGORY_MEDICAL,
104
+
105
+ protos.HarmCategory.HARM_CATEGORY_DANGEROUS: protos.HarmCategory.HARM_CATEGORY_DANGEROUS,
106
+ HarmCategory.HARM_CATEGORY_DANGEROUS: protos.HarmCategory.HARM_CATEGORY_DANGEROUS,
107
+ 6: protos.HarmCategory.HARM_CATEGORY_DANGEROUS,
108
+ "harm_category_dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS,
109
+ "dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS,
110
+ "danger": protos.HarmCategory.HARM_CATEGORY_DANGEROUS,
111
+ }
112
+ # fmt: on
113
+
114
+
115
+ def to_harm_category(x: HarmCategoryOptions) -> protos.HarmCategory:
116
+ if isinstance(x, str):
117
+ x = x.lower()
118
+ return _HARM_CATEGORIES[x]
119
+
120
+
121
+ HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold]
122
+
123
+ # fmt: off
124
+ _BLOCK_THRESHOLDS: Dict[HarmBlockThresholdOptions, HarmBlockThreshold] = {
125
+ HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
126
+ 0: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
127
+ "harm_block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
128
+ "block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
129
+ "unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
130
+
131
+ HarmBlockThreshold.BLOCK_LOW_AND_ABOVE: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
132
+ 1: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
133
+ "block_low_and_above": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
134
+ "low": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
135
+
136
+ HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
137
+ 2: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
138
+ "block_medium_and_above": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
139
+ "medium": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
140
+ "med": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
141
+
142
+ HarmBlockThreshold.BLOCK_ONLY_HIGH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
143
+ 3: HarmBlockThreshold.BLOCK_ONLY_HIGH,
144
+ "block_only_high": HarmBlockThreshold.BLOCK_ONLY_HIGH,
145
+ "high": HarmBlockThreshold.BLOCK_ONLY_HIGH,
146
+
147
+ HarmBlockThreshold.BLOCK_NONE: HarmBlockThreshold.BLOCK_NONE,
148
+ 4: HarmBlockThreshold.BLOCK_NONE,
149
+ "block_none": HarmBlockThreshold.BLOCK_NONE,
150
+ }
151
+ # fmt: on
152
+
153
+
154
+ def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmBlockThreshold:
155
+ if isinstance(x, str):
156
+ x = x.lower()
157
+ return _BLOCK_THRESHOLDS[x]
158
+
159
+
160
+ class ContentFilterDict(TypedDict):
161
+ reason: BlockedReason
162
+ message: str
163
+
164
+ __doc__ = string_utils.strip_oneof(protos.ContentFilter.__doc__)
165
+
166
+
167
+ def convert_filters_to_enums(
168
+ filters: Iterable[dict],
169
+ ) -> List[ContentFilterDict]:
170
+ result = []
171
+ for f in filters:
172
+ f = f.copy()
173
+ f["reason"] = BlockedReason(f["reason"])
174
+ f = typing.cast(ContentFilterDict, f)
175
+ result.append(f)
176
+ return result
177
+
178
+
179
+ class SafetyRatingDict(TypedDict):
180
+ category: protos.HarmCategory
181
+ probability: HarmProbability
182
+
183
+ __doc__ = string_utils.strip_oneof(protos.SafetyRating.__doc__)
184
+
185
+
186
+ def convert_rating_to_enum(rating: dict) -> SafetyRatingDict:
187
+ return {
188
+ "category": protos.HarmCategory(rating["category"]),
189
+ "probability": HarmProbability(rating["probability"]),
190
+ }
191
+
192
+
193
+ def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]:
194
+ result = []
195
+ for r in ratings:
196
+ result.append(convert_rating_to_enum(r))
197
+ return result
198
+
199
+
200
+ class SafetySettingDict(TypedDict):
201
+ category: protos.HarmCategory
202
+ threshold: HarmBlockThreshold
203
+
204
+ __doc__ = string_utils.strip_oneof(protos.SafetySetting.__doc__)
205
+
206
+
207
+ class LooseSafetySettingDict(TypedDict):
208
+ category: HarmCategoryOptions
209
+ threshold: HarmBlockThresholdOptions
210
+
211
+
212
+ EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions]
213
+ EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions]
214
+
215
+ SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None]
216
+
217
+
218
+ def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict:
219
+ if settings is None:
220
+ return {}
221
+ elif isinstance(settings, Mapping):
222
+ return {to_harm_category(key): to_block_threshold(value) for key, value in settings.items()}
223
+ else: # Iterable
224
+ return {
225
+ to_harm_category(d["category"]): to_block_threshold(d["threshold"]) for d in settings
226
+ }
227
+
228
+
229
+ def normalize_safety_settings(
230
+ settings: SafetySettingOptions,
231
+ ) -> list[SafetySettingDict] | None:
232
+ if settings is None:
233
+ return None
234
+ if isinstance(settings, Mapping):
235
+ return [
236
+ {
237
+ "category": to_harm_category(key),
238
+ "threshold": to_block_threshold(value),
239
+ }
240
+ for key, value in settings.items()
241
+ ]
242
+ else:
243
+ return [
244
+ {
245
+ "category": to_harm_category(d["category"]),
246
+ "threshold": to_block_threshold(d["threshold"]),
247
+ }
248
+ for d in settings
249
+ ]
250
+
251
+
252
+ def convert_setting_to_enum(setting: dict) -> SafetySettingDict:
253
+ return {
254
+ "category": protos.HarmCategory(setting["category"]),
255
+ "threshold": HarmBlockThreshold(setting["threshold"]),
256
+ }
257
+
258
+
259
+ class SafetyFeedbackDict(TypedDict):
260
+ rating: SafetyRatingDict
261
+ setting: SafetySettingDict
262
+
263
+ __doc__ = string_utils.strip_oneof(protos.SafetyFeedback.__doc__)
264
+
265
+
266
+ def convert_safety_feedback_to_enums(
267
+ safety_feedback: Iterable[dict],
268
+ ) -> List[SafetyFeedbackDict]:
269
+ result = []
270
+ for sf in safety_feedback:
271
+ result.append(
272
+ {
273
+ "rating": convert_rating_to_enum(sf["rating"]),
274
+ "setting": convert_setting_to_enum(sf["setting"]),
275
+ }
276
+ )
277
+ return result
278
+
279
+
280
+ def convert_candidate_enums(candidates):
281
+ result = []
282
+ for candidate in candidates:
283
+ candidate = candidate.copy()
284
+ candidate["safety_ratings"] = convert_ratings_to_enum(candidate["safety_ratings"])
285
+ result.append(candidate)
286
+ return result
.venv/lib/python3.11/site-packages/google/generativeai/types/safety_types.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ from collections.abc import Mapping
18
+
19
+ import enum
20
+ import typing
21
+ from typing import Dict, Iterable, List, Union
22
+
23
+ from typing_extensions import TypedDict
24
+
25
+
26
+ from google.generativeai import protos
27
+ from google.generativeai import string_utils
28
+
29
+
30
+ __all__ = [
31
+ "HarmCategory",
32
+ "HarmProbability",
33
+ "HarmBlockThreshold",
34
+ "BlockedReason",
35
+ "ContentFilterDict",
36
+ "SafetyRatingDict",
37
+ "SafetySettingDict",
38
+ "SafetyFeedbackDict",
39
+ ]
40
+
41
+ # These are basic python enums, it's okay to expose them
42
+ HarmProbability = protos.SafetyRating.HarmProbability
43
+ HarmBlockThreshold = protos.SafetySetting.HarmBlockThreshold
44
+ BlockedReason = protos.ContentFilter.BlockedReason
45
+
46
+ import proto
47
+
48
+
49
+ class HarmCategory(proto.Enum):
50
+ """
51
+ Harm Categories supported by the gemini-family model
52
+ """
53
+
54
+ HARM_CATEGORY_UNSPECIFIED = protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value
55
+ HARM_CATEGORY_HARASSMENT = protos.HarmCategory.HARM_CATEGORY_HARASSMENT.value
56
+ HARM_CATEGORY_HATE_SPEECH = protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH.value
57
+ HARM_CATEGORY_SEXUALLY_EXPLICIT = protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT.value
58
+ HARM_CATEGORY_DANGEROUS_CONTENT = protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT.value
59
+
60
+
61
+ HarmCategoryOptions = Union[str, int, HarmCategory]
62
+
63
+ # fmt: off
64
+ _HARM_CATEGORIES: Dict[HarmCategoryOptions, protos.HarmCategory] = {
65
+ protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
66
+ HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
67
+ 0: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
68
+ "harm_category_unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
69
+ "unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
70
+
71
+ 7: protos.HarmCategory.HARM_CATEGORY_HARASSMENT,
72
+ protos.HarmCategory.HARM_CATEGORY_HARASSMENT: protos.HarmCategory.HARM_CATEGORY_HARASSMENT,
73
+ HarmCategory.HARM_CATEGORY_HARASSMENT: protos.HarmCategory.HARM_CATEGORY_HARASSMENT,
74
+ "harm_category_harassment": protos.HarmCategory.HARM_CATEGORY_HARASSMENT,
75
+ "harassment": protos.HarmCategory.HARM_CATEGORY_HARASSMENT,
76
+
77
+ 8: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
78
+ protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
79
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
80
+ 'harm_category_hate_speech': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
81
+ 'hate_speech': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
82
+ 'hate': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
83
+
84
+ 9: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
85
+ protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
86
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
87
+ "harm_category_sexually_explicit": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
88
+ "harm_category_sexual": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
89
+ "sexually_explicit": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
90
+ "sexual": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
91
+ "sex": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
92
+
93
+ 10: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
94
+ protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
95
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
96
+ "harm_category_dangerous_content": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
97
+ "harm_category_dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
98
+ "dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
99
+ "danger": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
100
+ }
101
+ # fmt: on
102
+
103
+
104
+ def to_harm_category(x: HarmCategoryOptions) -> protos.HarmCategory:
105
+ if isinstance(x, str):
106
+ x = x.lower()
107
+ return _HARM_CATEGORIES[x]
108
+
109
+
110
+ HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold]
111
+
112
+ # fmt: off
113
+ _BLOCK_THRESHOLDS: Dict[HarmBlockThresholdOptions, HarmBlockThreshold] = {
114
+ HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
115
+ 0: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
116
+ "harm_block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
117
+ "block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
118
+ "unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
119
+
120
+ HarmBlockThreshold.BLOCK_LOW_AND_ABOVE: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
121
+ 1: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
122
+ "block_low_and_above": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
123
+ "low": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
124
+
125
+ HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
126
+ 2: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
127
+ "block_medium_and_above": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
128
+ "medium": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
129
+ "med": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
130
+
131
+ HarmBlockThreshold.BLOCK_ONLY_HIGH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
132
+ 3: HarmBlockThreshold.BLOCK_ONLY_HIGH,
133
+ "block_only_high": HarmBlockThreshold.BLOCK_ONLY_HIGH,
134
+ "high": HarmBlockThreshold.BLOCK_ONLY_HIGH,
135
+
136
+ HarmBlockThreshold.BLOCK_NONE: HarmBlockThreshold.BLOCK_NONE,
137
+ 4: HarmBlockThreshold.BLOCK_NONE,
138
+ "block_none": HarmBlockThreshold.BLOCK_NONE,
139
+ }
140
+ # fmt: on
141
+
142
+
143
+ def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmBlockThreshold:
144
+ if isinstance(x, str):
145
+ x = x.lower()
146
+ return _BLOCK_THRESHOLDS[x]
147
+
148
+
149
+ class ContentFilterDict(TypedDict):
150
+ reason: BlockedReason
151
+ message: str
152
+
153
+ __doc__ = string_utils.strip_oneof(protos.ContentFilter.__doc__)
154
+
155
+
156
+ def convert_filters_to_enums(
157
+ filters: Iterable[dict],
158
+ ) -> List[ContentFilterDict]:
159
+ result = []
160
+ for f in filters:
161
+ f = f.copy()
162
+ f["reason"] = BlockedReason(f["reason"])
163
+ f = typing.cast(ContentFilterDict, f)
164
+ result.append(f)
165
+ return result
166
+
167
+
168
+ class SafetyRatingDict(TypedDict):
169
+ category: protos.HarmCategory
170
+ probability: HarmProbability
171
+
172
+ __doc__ = string_utils.strip_oneof(protos.SafetyRating.__doc__)
173
+
174
+
175
+ def convert_rating_to_enum(rating: dict) -> SafetyRatingDict:
176
+ return {
177
+ "category": protos.HarmCategory(rating["category"]),
178
+ "probability": HarmProbability(rating["probability"]),
179
+ }
180
+
181
+
182
+ def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]:
183
+ result = []
184
+ for r in ratings:
185
+ result.append(convert_rating_to_enum(r))
186
+ return result
187
+
188
+
189
+ class SafetySettingDict(TypedDict):
190
+ category: protos.HarmCategory
191
+ threshold: HarmBlockThreshold
192
+
193
+ __doc__ = string_utils.strip_oneof(protos.SafetySetting.__doc__)
194
+
195
+
196
+ class LooseSafetySettingDict(TypedDict):
197
+ category: HarmCategoryOptions
198
+ threshold: HarmBlockThresholdOptions
199
+
200
+
201
+ EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions]
202
+ EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions]
203
+
204
+ SafetySettingOptions = Union[
205
+ HarmBlockThresholdOptions, EasySafetySetting, Iterable[LooseSafetySettingDict], None
206
+ ]
207
+
208
+
209
+ def _expand_block_threshold(block_threshold: HarmBlockThresholdOptions):
210
+ block_threshold = to_block_threshold(block_threshold)
211
+ hc = set(_HARM_CATEGORIES.values())
212
+ hc.remove(protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED)
213
+ return {category: block_threshold for category in hc}
214
+
215
+
216
+ def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict:
217
+ if settings is None:
218
+ return {}
219
+
220
+ if isinstance(settings, (int, str, HarmBlockThreshold)):
221
+ settings = _expand_block_threshold(settings)
222
+
223
+ if isinstance(settings, Mapping):
224
+ return {to_harm_category(key): to_block_threshold(value) for key, value in settings.items()}
225
+
226
+ else: # Iterable
227
+ result = {}
228
+ for setting in settings:
229
+ if isinstance(setting, protos.SafetySetting):
230
+ result[to_harm_category(setting.category)] = to_block_threshold(setting.threshold)
231
+ elif isinstance(setting, dict):
232
+ result[to_harm_category(setting["category"])] = to_block_threshold(
233
+ setting["threshold"]
234
+ )
235
+ else:
236
+ raise ValueError(
237
+ f"Could not understand safety setting:\n {type(setting)=}\n {setting=}"
238
+ )
239
+ return result
240
+
241
+
242
+ def normalize_safety_settings(
243
+ settings: SafetySettingOptions,
244
+ ) -> list[SafetySettingDict] | None:
245
+ if settings is None:
246
+ return None
247
+
248
+ if isinstance(settings, (int, str, HarmBlockThreshold)):
249
+ settings = _expand_block_threshold(settings)
250
+
251
+ if isinstance(settings, Mapping):
252
+ return [
253
+ {
254
+ "category": to_harm_category(key),
255
+ "threshold": to_block_threshold(value),
256
+ }
257
+ for key, value in settings.items()
258
+ ]
259
+ else:
260
+ return [
261
+ {
262
+ "category": to_harm_category(d["category"]),
263
+ "threshold": to_block_threshold(d["threshold"]),
264
+ }
265
+ for d in settings
266
+ ]
267
+
268
+
269
+ def convert_setting_to_enum(setting: dict) -> SafetySettingDict:
270
+ return {
271
+ "category": protos.HarmCategory(setting["category"]),
272
+ "threshold": HarmBlockThreshold(setting["threshold"]),
273
+ }
274
+
275
+
276
+ class SafetyFeedbackDict(TypedDict):
277
+ rating: SafetyRatingDict
278
+ setting: SafetySettingDict
279
+
280
+ __doc__ = string_utils.strip_oneof(protos.SafetyFeedback.__doc__)
281
+
282
+
283
+ def convert_safety_feedback_to_enums(
284
+ safety_feedback: Iterable[dict],
285
+ ) -> List[SafetyFeedbackDict]:
286
+ result = []
287
+ for sf in safety_feedback:
288
+ result.append(
289
+ {
290
+ "rating": convert_rating_to_enum(sf["rating"]),
291
+ "setting": convert_setting_to_enum(sf["setting"]),
292
+ }
293
+ )
294
+ return result
295
+
296
+
297
+ def convert_candidate_enums(candidates):
298
+ result = []
299
+ for candidate in candidates:
300
+ candidate = candidate.copy()
301
+ candidate["safety_ratings"] = convert_ratings_to_enum(candidate["safety_ratings"])
302
+ result.append(candidate)
303
+ return result
.venv/lib/python3.11/site-packages/google/generativeai/types/text_types.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ import sys
18
+ import abc
19
+ import dataclasses
20
+ from typing import Any, Dict, List
21
+ from typing_extensions import TypedDict
22
+
23
+ from google.generativeai import string_utils
24
+ from google.generativeai.types import citation_types
25
+
26
+
27
+ class EmbeddingDict(TypedDict):
28
+ embedding: list[float]
29
+
30
+
31
+ class BatchEmbeddingDict(TypedDict):
32
+ embedding: list[list[float]]
.venv/lib/python3.11/site-packages/google/logging/type/__pycache__/http_request_pb2.cpython-311.pyc ADDED
Binary file (2.28 kB). View file
 
.venv/lib/python3.11/site-packages/google/logging/type/__pycache__/log_severity_pb2.cpython-311.pyc ADDED
Binary file (1.94 kB). View file
 
.venv/lib/python3.11/site-packages/google/logging/type/http_request.proto ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright 2024 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ syntax = "proto3";
16
+
17
+ package google.logging.type;
18
+
19
+ import "google/protobuf/duration.proto";
20
+
21
+ option csharp_namespace = "Google.Cloud.Logging.Type";
22
+ option go_package = "google.golang.org/genproto/googleapis/logging/type;ltype";
23
+ option java_multiple_files = true;
24
+ option java_outer_classname = "HttpRequestProto";
25
+ option java_package = "com.google.logging.type";
26
+ option php_namespace = "Google\\Cloud\\Logging\\Type";
27
+ option ruby_package = "Google::Cloud::Logging::Type";
28
+
29
+ // A common proto for logging HTTP requests. Only contains semantics
30
+ // defined by the HTTP specification. Product-specific logging
31
+ // information MUST be defined in a separate message.
32
+ message HttpRequest {
33
+ // The request method. Examples: `"GET"`, `"HEAD"`, `"PUT"`, `"POST"`.
34
+ string request_method = 1;
35
+
36
+ // The scheme (http, https), the host name, the path and the query
37
+ // portion of the URL that was requested.
38
+ // Example: `"http://example.com/some/info?color=red"`.
39
+ string request_url = 2;
40
+
41
+ // The size of the HTTP request message in bytes, including the request
42
+ // headers and the request body.
43
+ int64 request_size = 3;
44
+
45
+ // The response code indicating the status of response.
46
+ // Examples: 200, 404.
47
+ int32 status = 4;
48
+
49
+ // The size of the HTTP response message sent back to the client, in bytes,
50
+ // including the response headers and the response body.
51
+ int64 response_size = 5;
52
+
53
+ // The user agent sent by the client. Example:
54
+ // `"Mozilla/4.0 (compatible; MSIE 6.0; Windows 98; Q312461; .NET
55
+ // CLR 1.0.3705)"`.
56
+ string user_agent = 6;
57
+
58
+ // The IP address (IPv4 or IPv6) of the client that issued the HTTP
59
+ // request. This field can include port information. Examples:
60
+ // `"192.168.1.1"`, `"10.0.0.1:80"`, `"FE80::0202:B3FF:FE1E:8329"`.
61
+ string remote_ip = 7;
62
+
63
+ // The IP address (IPv4 or IPv6) of the origin server that the request was
64
+ // sent to. This field can include port information. Examples:
65
+ // `"192.168.1.1"`, `"10.0.0.1:80"`, `"FE80::0202:B3FF:FE1E:8329"`.
66
+ string server_ip = 13;
67
+
68
+ // The referer URL of the request, as defined in
69
+ // [HTTP/1.1 Header Field
70
+ // Definitions](https://datatracker.ietf.org/doc/html/rfc2616#section-14.36).
71
+ string referer = 8;
72
+
73
+ // The request processing latency on the server, from the time the request was
74
+ // received until the response was sent.
75
+ google.protobuf.Duration latency = 14;
76
+
77
+ // Whether or not a cache lookup was attempted.
78
+ bool cache_lookup = 11;
79
+
80
+ // Whether or not an entity was served from cache
81
+ // (with or without validation).
82
+ bool cache_hit = 9;
83
+
84
+ // Whether or not the response was validated with the origin server before
85
+ // being served from cache. This field is only meaningful if `cache_hit` is
86
+ // True.
87
+ bool cache_validated_with_origin_server = 10;
88
+
89
+ // The number of HTTP response bytes inserted into cache. Set only when a
90
+ // cache fill was attempted.
91
+ int64 cache_fill_bytes = 12;
92
+
93
+ // Protocol used for the request. Examples: "HTTP/1.1", "HTTP/2", "websocket"
94
+ string protocol = 15;
95
+ }
.venv/lib/python3.11/site-packages/google/logging/type/log_severity.proto ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright 2024 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ syntax = "proto3";
16
+
17
+ package google.logging.type;
18
+
19
+ option csharp_namespace = "Google.Cloud.Logging.Type";
20
+ option go_package = "google.golang.org/genproto/googleapis/logging/type;ltype";
21
+ option java_multiple_files = true;
22
+ option java_outer_classname = "LogSeverityProto";
23
+ option java_package = "com.google.logging.type";
24
+ option objc_class_prefix = "GLOG";
25
+ option php_namespace = "Google\\Cloud\\Logging\\Type";
26
+ option ruby_package = "Google::Cloud::Logging::Type";
27
+
28
+ // The severity of the event described in a log entry, expressed as one of the
29
+ // standard severity levels listed below. For your reference, the levels are
30
+ // assigned the listed numeric values. The effect of using numeric values other
31
+ // than those listed is undefined.
32
+ //
33
+ // You can filter for log entries by severity. For example, the following
34
+ // filter expression will match log entries with severities `INFO`, `NOTICE`,
35
+ // and `WARNING`:
36
+ //
37
+ // severity > DEBUG AND severity <= WARNING
38
+ //
39
+ // If you are writing log entries, you should map other severity encodings to
40
+ // one of these standard levels. For example, you might map all of Java's FINE,
41
+ // FINER, and FINEST levels to `LogSeverity.DEBUG`. You can preserve the
42
+ // original severity level in the log entry payload if you wish.
43
+ enum LogSeverity {
44
+ // (0) The log entry has no assigned severity level.
45
+ DEFAULT = 0;
46
+
47
+ // (100) Debug or trace information.
48
+ DEBUG = 100;
49
+
50
+ // (200) Routine information, such as ongoing status or performance.
51
+ INFO = 200;
52
+
53
+ // (300) Normal but significant events, such as start up, shut down, or
54
+ // a configuration change.
55
+ NOTICE = 300;
56
+
57
+ // (400) Warning events might cause problems.
58
+ WARNING = 400;
59
+
60
+ // (500) Error events are likely to cause problems.
61
+ ERROR = 500;
62
+
63
+ // (600) Critical events cause more severe problems or outages.
64
+ CRITICAL = 600;
65
+
66
+ // (700) A person must take an action immediately.
67
+ ALERT = 700;
68
+
69
+ // (800) One or more systems are unusable.
70
+ EMERGENCY = 800;
71
+ }
.venv/lib/python3.11/site-packages/google/logging/type/log_severity_pb2.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2024 Google LLC
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
18
+ # source: google/logging/type/log_severity.proto
19
+ """Generated protocol buffer code."""
20
+ from google.protobuf import descriptor as _descriptor
21
+ from google.protobuf import descriptor_pool as _descriptor_pool
22
+ from google.protobuf import symbol_database as _symbol_database
23
+ from google.protobuf.internal import builder as _builder
24
+
25
+ # @@protoc_insertion_point(imports)
26
+
27
+ _sym_db = _symbol_database.Default()
28
+
29
+
30
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
31
+ b"\n&google/logging/type/log_severity.proto\x12\x13google.logging.type*\x82\x01\n\x0bLogSeverity\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x64\x12\t\n\x04INFO\x10\xc8\x01\x12\x0b\n\x06NOTICE\x10\xac\x02\x12\x0c\n\x07WARNING\x10\x90\x03\x12\n\n\x05\x45RROR\x10\xf4\x03\x12\r\n\x08\x43RITICAL\x10\xd8\x04\x12\n\n\x05\x41LERT\x10\xbc\x05\x12\x0e\n\tEMERGENCY\x10\xa0\x06\x42\xc5\x01\n\x17\x63om.google.logging.typeB\x10LogSeverityProtoP\x01Z8google.golang.org/genproto/googleapis/logging/type;ltype\xa2\x02\x04GLOG\xaa\x02\x19Google.Cloud.Logging.Type\xca\x02\x19Google\\Cloud\\Logging\\Type\xea\x02\x1cGoogle::Cloud::Logging::Typeb\x06proto3"
32
+ )
33
+
34
+ _globals = globals()
35
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
36
+ _builder.BuildTopDescriptorsAndMessages(
37
+ DESCRIPTOR, "google.logging.type.log_severity_pb2", _globals
38
+ )
39
+ if _descriptor._USE_C_DESCRIPTORS == False:
40
+ DESCRIPTOR._options = None
41
+ DESCRIPTOR._serialized_options = b"\n\027com.google.logging.typeB\020LogSeverityProtoP\001Z8google.golang.org/genproto/googleapis/logging/type;ltype\242\002\004GLOG\252\002\031Google.Cloud.Logging.Type\312\002\031Google\\Cloud\\Logging\\Type\352\002\034Google::Cloud::Logging::Type"
42
+ _globals["_LOGSEVERITY"]._serialized_start = 64
43
+ _globals["_LOGSEVERITY"]._serialized_end = 194
44
+ # @@protoc_insertion_point(module_scope)
.venv/lib/python3.11/site-packages/google/protobuf/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protocol Buffers - Google's data interchange format
2
+ # Copyright 2008 Google Inc. All rights reserved.
3
+ #
4
+ # Use of this source code is governed by a BSD-style
5
+ # license that can be found in the LICENSE file or at
6
+ # https://developers.google.com/open-source/licenses/bsd
7
+
8
+ # Copyright 2007 Google Inc. All Rights Reserved.
9
+
10
+ __version__ = '5.29.3'
.venv/lib/python3.11/site-packages/google/protobuf/any.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protocol Buffers - Google's data interchange format
2
+ # Copyright 2008 Google Inc. All rights reserved.
3
+ #
4
+ # Use of this source code is governed by a BSD-style
5
+ # license that can be found in the LICENSE file or at
6
+ # https://developers.google.com/open-source/licenses/bsd
7
+
8
+ """Contains the Any helper APIs."""
9
+
10
+ from typing import Optional
11
+
12
+ from google.protobuf import descriptor
13
+ from google.protobuf.message import Message
14
+
15
+ from google.protobuf.any_pb2 import Any
16
+
17
+
18
+ def pack(
19
+ msg: Message,
20
+ type_url_prefix: Optional[str] = 'type.googleapis.com/',
21
+ deterministic: Optional[bool] = None,
22
+ ) -> Any:
23
+ any_msg = Any()
24
+ any_msg.Pack(
25
+ msg=msg, type_url_prefix=type_url_prefix, deterministic=deterministic
26
+ )
27
+ return any_msg
28
+
29
+
30
+ def unpack(any_msg: Any, msg: Message) -> bool:
31
+ return any_msg.Unpack(msg=msg)
32
+
33
+
34
+ def type_name(any_msg: Any) -> str:
35
+ return any_msg.TypeName()
36
+
37
+
38
+ def is_type(any_msg: Any, des: descriptor.Descriptor) -> bool:
39
+ return any_msg.Is(des)
.venv/lib/python3.11/site-packages/google/protobuf/any_pb2.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
4
+ # source: google/protobuf/any.proto
5
+ # Protobuf Python Version: 5.29.3
6
+ """Generated protocol buffer code."""
7
+ from google.protobuf import descriptor as _descriptor
8
+ from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
10
+ from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 5,
15
+ 29,
16
+ 3,
17
+ '',
18
+ 'google/protobuf/any.proto'
19
+ )
20
+ # @@protoc_insertion_point(imports)
21
+
22
+ _sym_db = _symbol_database.Default()
23
+
24
+
25
+
26
+
27
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19google/protobuf/any.proto\x12\x0fgoogle.protobuf\"6\n\x03\x41ny\x12\x19\n\x08type_url\x18\x01 \x01(\tR\x07typeUrl\x12\x14\n\x05value\x18\x02 \x01(\x0cR\x05valueBv\n\x13\x63om.google.protobufB\x08\x41nyProtoP\x01Z,google.golang.org/protobuf/types/known/anypb\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')
28
+
29
+ _globals = globals()
30
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
31
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.any_pb2', _globals)
32
+ if not _descriptor._USE_C_DESCRIPTORS:
33
+ _globals['DESCRIPTOR']._loaded_options = None
34
+ _globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\010AnyProtoP\001Z,google.golang.org/protobuf/types/known/anypb\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
35
+ _globals['_ANY']._serialized_start=46
36
+ _globals['_ANY']._serialized_end=100
37
+ # @@protoc_insertion_point(module_scope)
.venv/lib/python3.11/site-packages/google/protobuf/api_pb2.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
4
+ # source: google/protobuf/api.proto
5
+ # Protobuf Python Version: 5.29.3
6
+ """Generated protocol buffer code."""
7
+ from google.protobuf import descriptor as _descriptor
8
+ from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
10
+ from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 5,
15
+ 29,
16
+ 3,
17
+ '',
18
+ 'google/protobuf/api.proto'
19
+ )
20
+ # @@protoc_insertion_point(imports)
21
+
22
+ _sym_db = _symbol_database.Default()
23
+
24
+
25
+ from google.protobuf import source_context_pb2 as google_dot_protobuf_dot_source__context__pb2
26
+ from google.protobuf import type_pb2 as google_dot_protobuf_dot_type__pb2
27
+
28
+
29
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19google/protobuf/api.proto\x12\x0fgoogle.protobuf\x1a$google/protobuf/source_context.proto\x1a\x1agoogle/protobuf/type.proto\"\xc1\x02\n\x03\x41pi\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x31\n\x07methods\x18\x02 \x03(\x0b\x32\x17.google.protobuf.MethodR\x07methods\x12\x31\n\x07options\x18\x03 \x03(\x0b\x32\x17.google.protobuf.OptionR\x07options\x12\x18\n\x07version\x18\x04 \x01(\tR\x07version\x12\x45\n\x0esource_context\x18\x05 \x01(\x0b\x32\x1e.google.protobuf.SourceContextR\rsourceContext\x12.\n\x06mixins\x18\x06 \x03(\x0b\x32\x16.google.protobuf.MixinR\x06mixins\x12/\n\x06syntax\x18\x07 \x01(\x0e\x32\x17.google.protobuf.SyntaxR\x06syntax\"\xb2\x02\n\x06Method\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12(\n\x10request_type_url\x18\x02 \x01(\tR\x0erequestTypeUrl\x12+\n\x11request_streaming\x18\x03 \x01(\x08R\x10requestStreaming\x12*\n\x11response_type_url\x18\x04 \x01(\tR\x0fresponseTypeUrl\x12-\n\x12response_streaming\x18\x05 \x01(\x08R\x11responseStreaming\x12\x31\n\x07options\x18\x06 \x03(\x0b\x32\x17.google.protobuf.OptionR\x07options\x12/\n\x06syntax\x18\x07 \x01(\x0e\x32\x17.google.protobuf.SyntaxR\x06syntax\"/\n\x05Mixin\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x12\n\x04root\x18\x02 \x01(\tR\x04rootBv\n\x13\x63om.google.protobufB\x08\x41piProtoP\x01Z,google.golang.org/protobuf/types/known/apipb\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')
30
+
31
+ _globals = globals()
32
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
33
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.api_pb2', _globals)
34
+ if not _descriptor._USE_C_DESCRIPTORS:
35
+ _globals['DESCRIPTOR']._loaded_options = None
36
+ _globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\010ApiProtoP\001Z,google.golang.org/protobuf/types/known/apipb\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
37
+ _globals['_API']._serialized_start=113
38
+ _globals['_API']._serialized_end=434
39
+ _globals['_METHOD']._serialized_start=437
40
+ _globals['_METHOD']._serialized_end=743
41
+ _globals['_MIXIN']._serialized_start=745
42
+ _globals['_MIXIN']._serialized_end=792
43
+ # @@protoc_insertion_point(module_scope)
.venv/lib/python3.11/site-packages/google/protobuf/descriptor.py ADDED
@@ -0,0 +1,1511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protocol Buffers - Google's data interchange format
2
+ # Copyright 2008 Google Inc. All rights reserved.
3
+ #
4
+ # Use of this source code is governed by a BSD-style
5
+ # license that can be found in the LICENSE file or at
6
+ # https://developers.google.com/open-source/licenses/bsd
7
+
8
+ """Descriptors essentially contain exactly the information found in a .proto
9
+ file, in types that make this information accessible in Python.
10
+ """
11
+
12
+ __author__ = 'robinson@google.com (Will Robinson)'
13
+
14
+ import abc
15
+ import binascii
16
+ import os
17
+ import threading
18
+ import warnings
19
+
20
+ from google.protobuf.internal import api_implementation
21
+
22
+ _USE_C_DESCRIPTORS = False
23
+ if api_implementation.Type() != 'python':
24
+ # pylint: disable=protected-access
25
+ _message = api_implementation._c_module
26
+ # TODO: Remove this import after fix api_implementation
27
+ if _message is None:
28
+ from google.protobuf.pyext import _message
29
+ _USE_C_DESCRIPTORS = True
30
+
31
+
32
+ class Error(Exception):
33
+ """Base error for this module."""
34
+
35
+
36
+ class TypeTransformationError(Error):
37
+ """Error transforming between python proto type and corresponding C++ type."""
38
+
39
+
40
+ if _USE_C_DESCRIPTORS:
41
+ # This metaclass allows to override the behavior of code like
42
+ # isinstance(my_descriptor, FieldDescriptor)
43
+ # and make it return True when the descriptor is an instance of the extension
44
+ # type written in C++.
45
+ class DescriptorMetaclass(type):
46
+
47
+ def __instancecheck__(cls, obj):
48
+ if super(DescriptorMetaclass, cls).__instancecheck__(obj):
49
+ return True
50
+ if isinstance(obj, cls._C_DESCRIPTOR_CLASS):
51
+ return True
52
+ return False
53
+ else:
54
+ # The standard metaclass; nothing changes.
55
+ DescriptorMetaclass = abc.ABCMeta
56
+
57
+
58
+ class _Lock(object):
59
+ """Wrapper class of threading.Lock(), which is allowed by 'with'."""
60
+
61
+ def __new__(cls):
62
+ self = object.__new__(cls)
63
+ self._lock = threading.Lock() # pylint: disable=protected-access
64
+ return self
65
+
66
+ def __enter__(self):
67
+ self._lock.acquire()
68
+
69
+ def __exit__(self, exc_type, exc_value, exc_tb):
70
+ self._lock.release()
71
+
72
+
73
+ _lock = threading.Lock()
74
+
75
+
76
+ def _Deprecated(name):
77
+ if _Deprecated.count > 0:
78
+ _Deprecated.count -= 1
79
+ warnings.warn(
80
+ 'Call to deprecated create function %s(). Note: Create unlinked '
81
+ 'descriptors is going to go away. Please use get/find descriptors from '
82
+ 'generated code or query the descriptor_pool.'
83
+ % name,
84
+ category=DeprecationWarning, stacklevel=3)
85
+
86
+ # These must match the values in descriptor.proto, but we can't use them
87
+ # directly because we sometimes need to reference them in feature helpers
88
+ # below *during* the build of descriptor.proto.
89
+ _FEATURESET_MESSAGE_ENCODING_DELIMITED = 2
90
+ _FEATURESET_FIELD_PRESENCE_IMPLICIT = 2
91
+ _FEATURESET_FIELD_PRESENCE_LEGACY_REQUIRED = 3
92
+ _FEATURESET_REPEATED_FIELD_ENCODING_PACKED = 1
93
+ _FEATURESET_ENUM_TYPE_CLOSED = 2
94
+
95
+ # Deprecated warnings will print 100 times at most which should be enough for
96
+ # users to notice and do not cause timeout.
97
+ _Deprecated.count = 100
98
+
99
+
100
+ _internal_create_key = object()
101
+
102
+
103
+ class DescriptorBase(metaclass=DescriptorMetaclass):
104
+
105
+ """Descriptors base class.
106
+
107
+ This class is the base of all descriptor classes. It provides common options
108
+ related functionality.
109
+
110
+ Attributes:
111
+ has_options: True if the descriptor has non-default options. Usually it is
112
+ not necessary to read this -- just call GetOptions() which will happily
113
+ return the default instance. However, it's sometimes useful for
114
+ efficiency, and also useful inside the protobuf implementation to avoid
115
+ some bootstrapping issues.
116
+ file (FileDescriptor): Reference to file info.
117
+ """
118
+
119
+ if _USE_C_DESCRIPTORS:
120
+ # The class, or tuple of classes, that are considered as "virtual
121
+ # subclasses" of this descriptor class.
122
+ _C_DESCRIPTOR_CLASS = ()
123
+
124
+ def __init__(self, file, options, serialized_options, options_class_name):
125
+ """Initialize the descriptor given its options message and the name of the
126
+ class of the options message. The name of the class is required in case
127
+ the options message is None and has to be created.
128
+ """
129
+ self._features = None
130
+ self.file = file
131
+ self._options = options
132
+ self._loaded_options = None
133
+ self._options_class_name = options_class_name
134
+ self._serialized_options = serialized_options
135
+
136
+ # Does this descriptor have non-default options?
137
+ self.has_options = (self._options is not None) or (
138
+ self._serialized_options is not None
139
+ )
140
+
141
+ @property
142
+ @abc.abstractmethod
143
+ def _parent(self):
144
+ pass
145
+
146
+ def _InferLegacyFeatures(self, edition, options, features):
147
+ """Infers features from proto2/proto3 syntax so that editions logic can be used everywhere.
148
+
149
+ Args:
150
+ edition: The edition to infer features for.
151
+ options: The options for this descriptor that are being processed.
152
+ features: The feature set object to modify with inferred features.
153
+ """
154
+ pass
155
+
156
+ def _GetFeatures(self):
157
+ if not self._features:
158
+ self._LazyLoadOptions()
159
+ return self._features
160
+
161
+ def _ResolveFeatures(self, edition, raw_options):
162
+ """Resolves features from the raw options of this descriptor.
163
+
164
+ Args:
165
+ edition: The edition to use for feature defaults.
166
+ raw_options: The options for this descriptor that are being processed.
167
+
168
+ Returns:
169
+ A fully resolved feature set for making runtime decisions.
170
+ """
171
+ # pylint: disable=g-import-not-at-top
172
+ from google.protobuf import descriptor_pb2
173
+
174
+ if self._parent:
175
+ features = descriptor_pb2.FeatureSet()
176
+ features.CopyFrom(self._parent._GetFeatures())
177
+ else:
178
+ features = self.file.pool._CreateDefaultFeatures(edition)
179
+ unresolved = descriptor_pb2.FeatureSet()
180
+ unresolved.CopyFrom(raw_options.features)
181
+ self._InferLegacyFeatures(edition, raw_options, unresolved)
182
+ features.MergeFrom(unresolved)
183
+
184
+ # Use the feature cache to reduce memory bloat.
185
+ return self.file.pool._InternFeatures(features)
186
+
187
+ def _LazyLoadOptions(self):
188
+ """Lazily initializes descriptor options towards the end of the build."""
189
+ if self._loaded_options:
190
+ return
191
+
192
+ # pylint: disable=g-import-not-at-top
193
+ from google.protobuf import descriptor_pb2
194
+
195
+ if not hasattr(descriptor_pb2, self._options_class_name):
196
+ raise RuntimeError(
197
+ 'Unknown options class name %s!' % self._options_class_name
198
+ )
199
+ options_class = getattr(descriptor_pb2, self._options_class_name)
200
+ features = None
201
+ edition = self.file._edition
202
+
203
+ if not self.has_options:
204
+ if not self._features:
205
+ features = self._ResolveFeatures(
206
+ descriptor_pb2.Edition.Value(edition), options_class()
207
+ )
208
+ with _lock:
209
+ self._loaded_options = options_class()
210
+ if not self._features:
211
+ self._features = features
212
+ else:
213
+ if not self._serialized_options:
214
+ options = self._options
215
+ else:
216
+ options = _ParseOptions(options_class(), self._serialized_options)
217
+
218
+ if not self._features:
219
+ features = self._ResolveFeatures(
220
+ descriptor_pb2.Edition.Value(edition), options
221
+ )
222
+ with _lock:
223
+ self._loaded_options = options
224
+ if not self._features:
225
+ self._features = features
226
+ if options.HasField('features'):
227
+ options.ClearField('features')
228
+ if not options.SerializeToString():
229
+ self._loaded_options = options_class()
230
+ self.has_options = False
231
+
232
+ def GetOptions(self):
233
+ """Retrieves descriptor options.
234
+
235
+ Returns:
236
+ The options set on this descriptor.
237
+ """
238
+ if not self._loaded_options:
239
+ self._LazyLoadOptions()
240
+ return self._loaded_options
241
+
242
+
243
+ class _NestedDescriptorBase(DescriptorBase):
244
+ """Common class for descriptors that can be nested."""
245
+
246
+ def __init__(self, options, options_class_name, name, full_name,
247
+ file, containing_type, serialized_start=None,
248
+ serialized_end=None, serialized_options=None):
249
+ """Constructor.
250
+
251
+ Args:
252
+ options: Protocol message options or None to use default message options.
253
+ options_class_name (str): The class name of the above options.
254
+ name (str): Name of this protocol message type.
255
+ full_name (str): Fully-qualified name of this protocol message type, which
256
+ will include protocol "package" name and the name of any enclosing
257
+ types.
258
+ containing_type: if provided, this is a nested descriptor, with this
259
+ descriptor as parent, otherwise None.
260
+ serialized_start: The start index (inclusive) in block in the
261
+ file.serialized_pb that describes this descriptor.
262
+ serialized_end: The end index (exclusive) in block in the
263
+ file.serialized_pb that describes this descriptor.
264
+ serialized_options: Protocol message serialized options or None.
265
+ """
266
+ super(_NestedDescriptorBase, self).__init__(
267
+ file, options, serialized_options, options_class_name
268
+ )
269
+
270
+ self.name = name
271
+ # TODO: Add function to calculate full_name instead of having it in
272
+ # memory?
273
+ self.full_name = full_name
274
+ self.containing_type = containing_type
275
+
276
+ self._serialized_start = serialized_start
277
+ self._serialized_end = serialized_end
278
+
279
+ def CopyToProto(self, proto):
280
+ """Copies this to the matching proto in descriptor_pb2.
281
+
282
+ Args:
283
+ proto: An empty proto instance from descriptor_pb2.
284
+
285
+ Raises:
286
+ Error: If self couldn't be serialized, due to to few constructor
287
+ arguments.
288
+ """
289
+ if (self.file is not None and
290
+ self._serialized_start is not None and
291
+ self._serialized_end is not None):
292
+ proto.ParseFromString(self.file.serialized_pb[
293
+ self._serialized_start:self._serialized_end])
294
+ else:
295
+ raise Error('Descriptor does not contain serialization.')
296
+
297
+
298
+ class Descriptor(_NestedDescriptorBase):
299
+
300
+ """Descriptor for a protocol message type.
301
+
302
+ Attributes:
303
+ name (str): Name of this protocol message type.
304
+ full_name (str): Fully-qualified name of this protocol message type,
305
+ which will include protocol "package" name and the name of any
306
+ enclosing types.
307
+ containing_type (Descriptor): Reference to the descriptor of the type
308
+ containing us, or None if this is top-level.
309
+ fields (list[FieldDescriptor]): Field descriptors for all fields in
310
+ this type.
311
+ fields_by_number (dict(int, FieldDescriptor)): Same
312
+ :class:`FieldDescriptor` objects as in :attr:`fields`, but indexed
313
+ by "number" attribute in each FieldDescriptor.
314
+ fields_by_name (dict(str, FieldDescriptor)): Same
315
+ :class:`FieldDescriptor` objects as in :attr:`fields`, but indexed by
316
+ "name" attribute in each :class:`FieldDescriptor`.
317
+ nested_types (list[Descriptor]): Descriptor references
318
+ for all protocol message types nested within this one.
319
+ nested_types_by_name (dict(str, Descriptor)): Same Descriptor
320
+ objects as in :attr:`nested_types`, but indexed by "name" attribute
321
+ in each Descriptor.
322
+ enum_types (list[EnumDescriptor]): :class:`EnumDescriptor` references
323
+ for all enums contained within this type.
324
+ enum_types_by_name (dict(str, EnumDescriptor)): Same
325
+ :class:`EnumDescriptor` objects as in :attr:`enum_types`, but
326
+ indexed by "name" attribute in each EnumDescriptor.
327
+ enum_values_by_name (dict(str, EnumValueDescriptor)): Dict mapping
328
+ from enum value name to :class:`EnumValueDescriptor` for that value.
329
+ extensions (list[FieldDescriptor]): All extensions defined directly
330
+ within this message type (NOT within a nested type).
331
+ extensions_by_name (dict(str, FieldDescriptor)): Same FieldDescriptor
332
+ objects as :attr:`extensions`, but indexed by "name" attribute of each
333
+ FieldDescriptor.
334
+ is_extendable (bool): Does this type define any extension ranges?
335
+ oneofs (list[OneofDescriptor]): The list of descriptors for oneof fields
336
+ in this message.
337
+ oneofs_by_name (dict(str, OneofDescriptor)): Same objects as in
338
+ :attr:`oneofs`, but indexed by "name" attribute.
339
+ file (FileDescriptor): Reference to file descriptor.
340
+ is_map_entry: If the message type is a map entry.
341
+
342
+ """
343
+
344
+ if _USE_C_DESCRIPTORS:
345
+ _C_DESCRIPTOR_CLASS = _message.Descriptor
346
+
347
+ def __new__(
348
+ cls,
349
+ name=None,
350
+ full_name=None,
351
+ filename=None,
352
+ containing_type=None,
353
+ fields=None,
354
+ nested_types=None,
355
+ enum_types=None,
356
+ extensions=None,
357
+ options=None,
358
+ serialized_options=None,
359
+ is_extendable=True,
360
+ extension_ranges=None,
361
+ oneofs=None,
362
+ file=None, # pylint: disable=redefined-builtin
363
+ serialized_start=None,
364
+ serialized_end=None,
365
+ syntax=None,
366
+ is_map_entry=False,
367
+ create_key=None):
368
+ _message.Message._CheckCalledFromGeneratedFile()
369
+ return _message.default_pool.FindMessageTypeByName(full_name)
370
+
371
+ # NOTE: The file argument redefining a builtin is nothing we can
372
+ # fix right now since we don't know how many clients already rely on the
373
+ # name of the argument.
374
+ def __init__(self, name, full_name, filename, containing_type, fields,
375
+ nested_types, enum_types, extensions, options=None,
376
+ serialized_options=None,
377
+ is_extendable=True, extension_ranges=None, oneofs=None,
378
+ file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin
379
+ syntax=None, is_map_entry=False, create_key=None):
380
+ """Arguments to __init__() are as described in the description
381
+ of Descriptor fields above.
382
+
383
+ Note that filename is an obsolete argument, that is not used anymore.
384
+ Please use file.name to access this as an attribute.
385
+ """
386
+ if create_key is not _internal_create_key:
387
+ _Deprecated('Descriptor')
388
+
389
+ super(Descriptor, self).__init__(
390
+ options, 'MessageOptions', name, full_name, file,
391
+ containing_type, serialized_start=serialized_start,
392
+ serialized_end=serialized_end, serialized_options=serialized_options)
393
+
394
+ # We have fields in addition to fields_by_name and fields_by_number,
395
+ # so that:
396
+ # 1. Clients can index fields by "order in which they're listed."
397
+ # 2. Clients can easily iterate over all fields with the terse
398
+ # syntax: for f in descriptor.fields: ...
399
+ self.fields = fields
400
+ for field in self.fields:
401
+ field.containing_type = self
402
+ field.file = file
403
+ self.fields_by_number = dict((f.number, f) for f in fields)
404
+ self.fields_by_name = dict((f.name, f) for f in fields)
405
+ self._fields_by_camelcase_name = None
406
+
407
+ self.nested_types = nested_types
408
+ for nested_type in nested_types:
409
+ nested_type.containing_type = self
410
+ self.nested_types_by_name = dict((t.name, t) for t in nested_types)
411
+
412
+ self.enum_types = enum_types
413
+ for enum_type in self.enum_types:
414
+ enum_type.containing_type = self
415
+ self.enum_types_by_name = dict((t.name, t) for t in enum_types)
416
+ self.enum_values_by_name = dict(
417
+ (v.name, v) for t in enum_types for v in t.values)
418
+
419
+ self.extensions = extensions
420
+ for extension in self.extensions:
421
+ extension.extension_scope = self
422
+ self.extensions_by_name = dict((f.name, f) for f in extensions)
423
+ self.is_extendable = is_extendable
424
+ self.extension_ranges = extension_ranges
425
+ self.oneofs = oneofs if oneofs is not None else []
426
+ self.oneofs_by_name = dict((o.name, o) for o in self.oneofs)
427
+ for oneof in self.oneofs:
428
+ oneof.containing_type = self
429
+ oneof.file = file
430
+ self._is_map_entry = is_map_entry
431
+
432
+ @property
433
+ def _parent(self):
434
+ return self.containing_type or self.file
435
+
436
+ @property
437
+ def fields_by_camelcase_name(self):
438
+ """Same FieldDescriptor objects as in :attr:`fields`, but indexed by
439
+ :attr:`FieldDescriptor.camelcase_name`.
440
+ """
441
+ if self._fields_by_camelcase_name is None:
442
+ self._fields_by_camelcase_name = dict(
443
+ (f.camelcase_name, f) for f in self.fields)
444
+ return self._fields_by_camelcase_name
445
+
446
+ def EnumValueName(self, enum, value):
447
+ """Returns the string name of an enum value.
448
+
449
+ This is just a small helper method to simplify a common operation.
450
+
451
+ Args:
452
+ enum: string name of the Enum.
453
+ value: int, value of the enum.
454
+
455
+ Returns:
456
+ string name of the enum value.
457
+
458
+ Raises:
459
+ KeyError if either the Enum doesn't exist or the value is not a valid
460
+ value for the enum.
461
+ """
462
+ return self.enum_types_by_name[enum].values_by_number[value].name
463
+
464
+ def CopyToProto(self, proto):
465
+ """Copies this to a descriptor_pb2.DescriptorProto.
466
+
467
+ Args:
468
+ proto: An empty descriptor_pb2.DescriptorProto.
469
+ """
470
+ # This function is overridden to give a better doc comment.
471
+ super(Descriptor, self).CopyToProto(proto)
472
+
473
+
474
+ # TODO: We should have aggressive checking here,
475
+ # for example:
476
+ # * If you specify a repeated field, you should not be allowed
477
+ # to specify a default value.
478
+ # * [Other examples here as needed].
479
+ #
480
+ # TODO: for this and other *Descriptor classes, we
481
+ # might also want to lock things down aggressively (e.g.,
482
+ # prevent clients from setting the attributes). Having
483
+ # stronger invariants here in general will reduce the number
484
+ # of runtime checks we must do in reflection.py...
485
+ class FieldDescriptor(DescriptorBase):
486
+
487
+ """Descriptor for a single field in a .proto file.
488
+
489
+ Attributes:
490
+ name (str): Name of this field, exactly as it appears in .proto.
491
+ full_name (str): Name of this field, including containing scope. This is
492
+ particularly relevant for extensions.
493
+ index (int): Dense, 0-indexed index giving the order that this
494
+ field textually appears within its message in the .proto file.
495
+ number (int): Tag number declared for this field in the .proto file.
496
+
497
+ type (int): (One of the TYPE_* constants below) Declared type.
498
+ cpp_type (int): (One of the CPPTYPE_* constants below) C++ type used to
499
+ represent this field.
500
+
501
+ label (int): (One of the LABEL_* constants below) Tells whether this
502
+ field is optional, required, or repeated.
503
+ has_default_value (bool): True if this field has a default value defined,
504
+ otherwise false.
505
+ default_value (Varies): Default value of this field. Only
506
+ meaningful for non-repeated scalar fields. Repeated fields
507
+ should always set this to [], and non-repeated composite
508
+ fields should always set this to None.
509
+
510
+ containing_type (Descriptor): Descriptor of the protocol message
511
+ type that contains this field. Set by the Descriptor constructor
512
+ if we're passed into one.
513
+ Somewhat confusingly, for extension fields, this is the
514
+ descriptor of the EXTENDED message, not the descriptor
515
+ of the message containing this field. (See is_extension and
516
+ extension_scope below).
517
+ message_type (Descriptor): If a composite field, a descriptor
518
+ of the message type contained in this field. Otherwise, this is None.
519
+ enum_type (EnumDescriptor): If this field contains an enum, a
520
+ descriptor of that enum. Otherwise, this is None.
521
+
522
+ is_extension: True iff this describes an extension field.
523
+ extension_scope (Descriptor): Only meaningful if is_extension is True.
524
+ Gives the message that immediately contains this extension field.
525
+ Will be None iff we're a top-level (file-level) extension field.
526
+
527
+ options (descriptor_pb2.FieldOptions): Protocol message field options or
528
+ None to use default field options.
529
+
530
+ containing_oneof (OneofDescriptor): If the field is a member of a oneof
531
+ union, contains its descriptor. Otherwise, None.
532
+
533
+ file (FileDescriptor): Reference to file descriptor.
534
+ """
535
+
536
+ # Must be consistent with C++ FieldDescriptor::Type enum in
537
+ # descriptor.h.
538
+ #
539
+ # TODO: Find a way to eliminate this repetition.
540
+ TYPE_DOUBLE = 1
541
+ TYPE_FLOAT = 2
542
+ TYPE_INT64 = 3
543
+ TYPE_UINT64 = 4
544
+ TYPE_INT32 = 5
545
+ TYPE_FIXED64 = 6
546
+ TYPE_FIXED32 = 7
547
+ TYPE_BOOL = 8
548
+ TYPE_STRING = 9
549
+ TYPE_GROUP = 10
550
+ TYPE_MESSAGE = 11
551
+ TYPE_BYTES = 12
552
+ TYPE_UINT32 = 13
553
+ TYPE_ENUM = 14
554
+ TYPE_SFIXED32 = 15
555
+ TYPE_SFIXED64 = 16
556
+ TYPE_SINT32 = 17
557
+ TYPE_SINT64 = 18
558
+ MAX_TYPE = 18
559
+
560
+ # Must be consistent with C++ FieldDescriptor::CppType enum in
561
+ # descriptor.h.
562
+ #
563
+ # TODO: Find a way to eliminate this repetition.
564
+ CPPTYPE_INT32 = 1
565
+ CPPTYPE_INT64 = 2
566
+ CPPTYPE_UINT32 = 3
567
+ CPPTYPE_UINT64 = 4
568
+ CPPTYPE_DOUBLE = 5
569
+ CPPTYPE_FLOAT = 6
570
+ CPPTYPE_BOOL = 7
571
+ CPPTYPE_ENUM = 8
572
+ CPPTYPE_STRING = 9
573
+ CPPTYPE_MESSAGE = 10
574
+ MAX_CPPTYPE = 10
575
+
576
+ _PYTHON_TO_CPP_PROTO_TYPE_MAP = {
577
+ TYPE_DOUBLE: CPPTYPE_DOUBLE,
578
+ TYPE_FLOAT: CPPTYPE_FLOAT,
579
+ TYPE_ENUM: CPPTYPE_ENUM,
580
+ TYPE_INT64: CPPTYPE_INT64,
581
+ TYPE_SINT64: CPPTYPE_INT64,
582
+ TYPE_SFIXED64: CPPTYPE_INT64,
583
+ TYPE_UINT64: CPPTYPE_UINT64,
584
+ TYPE_FIXED64: CPPTYPE_UINT64,
585
+ TYPE_INT32: CPPTYPE_INT32,
586
+ TYPE_SFIXED32: CPPTYPE_INT32,
587
+ TYPE_SINT32: CPPTYPE_INT32,
588
+ TYPE_UINT32: CPPTYPE_UINT32,
589
+ TYPE_FIXED32: CPPTYPE_UINT32,
590
+ TYPE_BYTES: CPPTYPE_STRING,
591
+ TYPE_STRING: CPPTYPE_STRING,
592
+ TYPE_BOOL: CPPTYPE_BOOL,
593
+ TYPE_MESSAGE: CPPTYPE_MESSAGE,
594
+ TYPE_GROUP: CPPTYPE_MESSAGE
595
+ }
596
+
597
+ # Must be consistent with C++ FieldDescriptor::Label enum in
598
+ # descriptor.h.
599
+ #
600
+ # TODO: Find a way to eliminate this repetition.
601
+ LABEL_OPTIONAL = 1
602
+ LABEL_REQUIRED = 2
603
+ LABEL_REPEATED = 3
604
+ MAX_LABEL = 3
605
+
606
+ # Must be consistent with C++ constants kMaxNumber, kFirstReservedNumber,
607
+ # and kLastReservedNumber in descriptor.h
608
+ MAX_FIELD_NUMBER = (1 << 29) - 1
609
+ FIRST_RESERVED_FIELD_NUMBER = 19000
610
+ LAST_RESERVED_FIELD_NUMBER = 19999
611
+
612
+ if _USE_C_DESCRIPTORS:
613
+ _C_DESCRIPTOR_CLASS = _message.FieldDescriptor
614
+
615
+ def __new__(cls, name, full_name, index, number, type, cpp_type, label,
616
+ default_value, message_type, enum_type, containing_type,
617
+ is_extension, extension_scope, options=None,
618
+ serialized_options=None,
619
+ has_default_value=True, containing_oneof=None, json_name=None,
620
+ file=None, create_key=None): # pylint: disable=redefined-builtin
621
+ _message.Message._CheckCalledFromGeneratedFile()
622
+ if is_extension:
623
+ return _message.default_pool.FindExtensionByName(full_name)
624
+ else:
625
+ return _message.default_pool.FindFieldByName(full_name)
626
+
627
+ def __init__(self, name, full_name, index, number, type, cpp_type, label,
628
+ default_value, message_type, enum_type, containing_type,
629
+ is_extension, extension_scope, options=None,
630
+ serialized_options=None,
631
+ has_default_value=True, containing_oneof=None, json_name=None,
632
+ file=None, create_key=None): # pylint: disable=redefined-builtin
633
+ """The arguments are as described in the description of FieldDescriptor
634
+ attributes above.
635
+
636
+ Note that containing_type may be None, and may be set later if necessary
637
+ (to deal with circular references between message types, for example).
638
+ Likewise for extension_scope.
639
+ """
640
+ if create_key is not _internal_create_key:
641
+ _Deprecated('FieldDescriptor')
642
+
643
+ super(FieldDescriptor, self).__init__(
644
+ file, options, serialized_options, 'FieldOptions'
645
+ )
646
+ self.name = name
647
+ self.full_name = full_name
648
+ self._camelcase_name = None
649
+ if json_name is None:
650
+ self.json_name = _ToJsonName(name)
651
+ else:
652
+ self.json_name = json_name
653
+ self.index = index
654
+ self.number = number
655
+ self._type = type
656
+ self.cpp_type = cpp_type
657
+ self._label = label
658
+ self.has_default_value = has_default_value
659
+ self.default_value = default_value
660
+ self.containing_type = containing_type
661
+ self.message_type = message_type
662
+ self.enum_type = enum_type
663
+ self.is_extension = is_extension
664
+ self.extension_scope = extension_scope
665
+ self.containing_oneof = containing_oneof
666
+ if api_implementation.Type() == 'python':
667
+ self._cdescriptor = None
668
+ else:
669
+ if is_extension:
670
+ self._cdescriptor = _message.default_pool.FindExtensionByName(full_name)
671
+ else:
672
+ self._cdescriptor = _message.default_pool.FindFieldByName(full_name)
673
+
674
+ @property
675
+ def _parent(self):
676
+ if self.containing_oneof:
677
+ return self.containing_oneof
678
+ if self.is_extension:
679
+ return self.extension_scope or self.file
680
+ return self.containing_type
681
+
682
+ def _InferLegacyFeatures(self, edition, options, features):
683
+ # pylint: disable=g-import-not-at-top
684
+ from google.protobuf import descriptor_pb2
685
+
686
+ if edition >= descriptor_pb2.Edition.EDITION_2023:
687
+ return
688
+
689
+ if self._label == FieldDescriptor.LABEL_REQUIRED:
690
+ features.field_presence = (
691
+ descriptor_pb2.FeatureSet.FieldPresence.LEGACY_REQUIRED
692
+ )
693
+
694
+ if self._type == FieldDescriptor.TYPE_GROUP:
695
+ features.message_encoding = (
696
+ descriptor_pb2.FeatureSet.MessageEncoding.DELIMITED
697
+ )
698
+
699
+ if options.HasField('packed'):
700
+ features.repeated_field_encoding = (
701
+ descriptor_pb2.FeatureSet.RepeatedFieldEncoding.PACKED
702
+ if options.packed
703
+ else descriptor_pb2.FeatureSet.RepeatedFieldEncoding.EXPANDED
704
+ )
705
+
706
+ @property
707
+ def type(self):
708
+ if (
709
+ self._GetFeatures().message_encoding
710
+ == _FEATURESET_MESSAGE_ENCODING_DELIMITED
711
+ and self.message_type
712
+ and not self.message_type.GetOptions().map_entry
713
+ and not self.containing_type.GetOptions().map_entry
714
+ ):
715
+ return FieldDescriptor.TYPE_GROUP
716
+ return self._type
717
+
718
+ @type.setter
719
+ def type(self, val):
720
+ self._type = val
721
+
722
+ @property
723
+ def label(self):
724
+ if (
725
+ self._GetFeatures().field_presence
726
+ == _FEATURESET_FIELD_PRESENCE_LEGACY_REQUIRED
727
+ ):
728
+ return FieldDescriptor.LABEL_REQUIRED
729
+ return self._label
730
+
731
+ @property
732
+ def camelcase_name(self):
733
+ """Camelcase name of this field.
734
+
735
+ Returns:
736
+ str: the name in CamelCase.
737
+ """
738
+ if self._camelcase_name is None:
739
+ self._camelcase_name = _ToCamelCase(self.name)
740
+ return self._camelcase_name
741
+
742
+ @property
743
+ def has_presence(self):
744
+ """Whether the field distinguishes between unpopulated and default values.
745
+
746
+ Raises:
747
+ RuntimeError: singular field that is not linked with message nor file.
748
+ """
749
+ if self.label == FieldDescriptor.LABEL_REPEATED:
750
+ return False
751
+ if (
752
+ self.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE
753
+ or self.is_extension
754
+ or self.containing_oneof
755
+ ):
756
+ return True
757
+
758
+ return (
759
+ self._GetFeatures().field_presence
760
+ != _FEATURESET_FIELD_PRESENCE_IMPLICIT
761
+ )
762
+
763
+ @property
764
+ def is_packed(self):
765
+ """Returns if the field is packed."""
766
+ if self.label != FieldDescriptor.LABEL_REPEATED:
767
+ return False
768
+ field_type = self.type
769
+ if (field_type == FieldDescriptor.TYPE_STRING or
770
+ field_type == FieldDescriptor.TYPE_GROUP or
771
+ field_type == FieldDescriptor.TYPE_MESSAGE or
772
+ field_type == FieldDescriptor.TYPE_BYTES):
773
+ return False
774
+
775
+ return (
776
+ self._GetFeatures().repeated_field_encoding
777
+ == _FEATURESET_REPEATED_FIELD_ENCODING_PACKED
778
+ )
779
+
780
+ @staticmethod
781
+ def ProtoTypeToCppProtoType(proto_type):
782
+ """Converts from a Python proto type to a C++ Proto Type.
783
+
784
+ The Python ProtocolBuffer classes specify both the 'Python' datatype and the
785
+ 'C++' datatype - and they're not the same. This helper method should
786
+ translate from one to another.
787
+
788
+ Args:
789
+ proto_type: the Python proto type (descriptor.FieldDescriptor.TYPE_*)
790
+ Returns:
791
+ int: descriptor.FieldDescriptor.CPPTYPE_*, the C++ type.
792
+ Raises:
793
+ TypeTransformationError: when the Python proto type isn't known.
794
+ """
795
+ try:
796
+ return FieldDescriptor._PYTHON_TO_CPP_PROTO_TYPE_MAP[proto_type]
797
+ except KeyError:
798
+ raise TypeTransformationError('Unknown proto_type: %s' % proto_type)
799
+
800
+
801
+ class EnumDescriptor(_NestedDescriptorBase):
802
+
803
+ """Descriptor for an enum defined in a .proto file.
804
+
805
+ Attributes:
806
+ name (str): Name of the enum type.
807
+ full_name (str): Full name of the type, including package name
808
+ and any enclosing type(s).
809
+
810
+ values (list[EnumValueDescriptor]): List of the values
811
+ in this enum.
812
+ values_by_name (dict(str, EnumValueDescriptor)): Same as :attr:`values`,
813
+ but indexed by the "name" field of each EnumValueDescriptor.
814
+ values_by_number (dict(int, EnumValueDescriptor)): Same as :attr:`values`,
815
+ but indexed by the "number" field of each EnumValueDescriptor.
816
+ containing_type (Descriptor): Descriptor of the immediate containing
817
+ type of this enum, or None if this is an enum defined at the
818
+ top level in a .proto file. Set by Descriptor's constructor
819
+ if we're passed into one.
820
+ file (FileDescriptor): Reference to file descriptor.
821
+ options (descriptor_pb2.EnumOptions): Enum options message or
822
+ None to use default enum options.
823
+ """
824
+
825
+ if _USE_C_DESCRIPTORS:
826
+ _C_DESCRIPTOR_CLASS = _message.EnumDescriptor
827
+
828
+ def __new__(cls, name, full_name, filename, values,
829
+ containing_type=None, options=None,
830
+ serialized_options=None, file=None, # pylint: disable=redefined-builtin
831
+ serialized_start=None, serialized_end=None, create_key=None):
832
+ _message.Message._CheckCalledFromGeneratedFile()
833
+ return _message.default_pool.FindEnumTypeByName(full_name)
834
+
835
+ def __init__(self, name, full_name, filename, values,
836
+ containing_type=None, options=None,
837
+ serialized_options=None, file=None, # pylint: disable=redefined-builtin
838
+ serialized_start=None, serialized_end=None, create_key=None):
839
+ """Arguments are as described in the attribute description above.
840
+
841
+ Note that filename is an obsolete argument, that is not used anymore.
842
+ Please use file.name to access this as an attribute.
843
+ """
844
+ if create_key is not _internal_create_key:
845
+ _Deprecated('EnumDescriptor')
846
+
847
+ super(EnumDescriptor, self).__init__(
848
+ options, 'EnumOptions', name, full_name, file,
849
+ containing_type, serialized_start=serialized_start,
850
+ serialized_end=serialized_end, serialized_options=serialized_options)
851
+
852
+ self.values = values
853
+ for value in self.values:
854
+ value.file = file
855
+ value.type = self
856
+ self.values_by_name = dict((v.name, v) for v in values)
857
+ # Values are reversed to ensure that the first alias is retained.
858
+ self.values_by_number = dict((v.number, v) for v in reversed(values))
859
+
860
+ @property
861
+ def _parent(self):
862
+ return self.containing_type or self.file
863
+
864
+ @property
865
+ def is_closed(self):
866
+ """Returns true whether this is a "closed" enum.
867
+
868
+ This means that it:
869
+ - Has a fixed set of values, rather than being equivalent to an int32.
870
+ - Encountering values not in this set causes them to be treated as unknown
871
+ fields.
872
+ - The first value (i.e., the default) may be nonzero.
873
+
874
+ WARNING: Some runtimes currently have a quirk where non-closed enums are
875
+ treated as closed when used as the type of fields defined in a
876
+ `syntax = proto2;` file. This quirk is not present in all runtimes; as of
877
+ writing, we know that:
878
+
879
+ - C++, Java, and C++-based Python share this quirk.
880
+ - UPB and UPB-based Python do not.
881
+ - PHP and Ruby treat all enums as open regardless of declaration.
882
+
883
+ Care should be taken when using this function to respect the target
884
+ runtime's enum handling quirks.
885
+ """
886
+ return self._GetFeatures().enum_type == _FEATURESET_ENUM_TYPE_CLOSED
887
+
888
+ def CopyToProto(self, proto):
889
+ """Copies this to a descriptor_pb2.EnumDescriptorProto.
890
+
891
+ Args:
892
+ proto (descriptor_pb2.EnumDescriptorProto): An empty descriptor proto.
893
+ """
894
+ # This function is overridden to give a better doc comment.
895
+ super(EnumDescriptor, self).CopyToProto(proto)
896
+
897
+
898
+ class EnumValueDescriptor(DescriptorBase):
899
+
900
+ """Descriptor for a single value within an enum.
901
+
902
+ Attributes:
903
+ name (str): Name of this value.
904
+ index (int): Dense, 0-indexed index giving the order that this
905
+ value appears textually within its enum in the .proto file.
906
+ number (int): Actual number assigned to this enum value.
907
+ type (EnumDescriptor): :class:`EnumDescriptor` to which this value
908
+ belongs. Set by :class:`EnumDescriptor`'s constructor if we're
909
+ passed into one.
910
+ options (descriptor_pb2.EnumValueOptions): Enum value options message or
911
+ None to use default enum value options options.
912
+ """
913
+
914
+ if _USE_C_DESCRIPTORS:
915
+ _C_DESCRIPTOR_CLASS = _message.EnumValueDescriptor
916
+
917
+ def __new__(cls, name, index, number,
918
+ type=None, # pylint: disable=redefined-builtin
919
+ options=None, serialized_options=None, create_key=None):
920
+ _message.Message._CheckCalledFromGeneratedFile()
921
+ # There is no way we can build a complete EnumValueDescriptor with the
922
+ # given parameters (the name of the Enum is not known, for example).
923
+ # Fortunately generated files just pass it to the EnumDescriptor()
924
+ # constructor, which will ignore it, so returning None is good enough.
925
+ return None
926
+
927
+ def __init__(self, name, index, number,
928
+ type=None, # pylint: disable=redefined-builtin
929
+ options=None, serialized_options=None, create_key=None):
930
+ """Arguments are as described in the attribute description above."""
931
+ if create_key is not _internal_create_key:
932
+ _Deprecated('EnumValueDescriptor')
933
+
934
+ super(EnumValueDescriptor, self).__init__(
935
+ type.file if type else None,
936
+ options,
937
+ serialized_options,
938
+ 'EnumValueOptions',
939
+ )
940
+ self.name = name
941
+ self.index = index
942
+ self.number = number
943
+ self.type = type
944
+
945
+ @property
946
+ def _parent(self):
947
+ return self.type
948
+
949
+
950
+ class OneofDescriptor(DescriptorBase):
951
+ """Descriptor for a oneof field.
952
+
953
+ Attributes:
954
+ name (str): Name of the oneof field.
955
+ full_name (str): Full name of the oneof field, including package name.
956
+ index (int): 0-based index giving the order of the oneof field inside
957
+ its containing type.
958
+ containing_type (Descriptor): :class:`Descriptor` of the protocol message
959
+ type that contains this field. Set by the :class:`Descriptor` constructor
960
+ if we're passed into one.
961
+ fields (list[FieldDescriptor]): The list of field descriptors this
962
+ oneof can contain.
963
+ """
964
+
965
+ if _USE_C_DESCRIPTORS:
966
+ _C_DESCRIPTOR_CLASS = _message.OneofDescriptor
967
+
968
+ def __new__(
969
+ cls, name, full_name, index, containing_type, fields, options=None,
970
+ serialized_options=None, create_key=None):
971
+ _message.Message._CheckCalledFromGeneratedFile()
972
+ return _message.default_pool.FindOneofByName(full_name)
973
+
974
+ def __init__(
975
+ self, name, full_name, index, containing_type, fields, options=None,
976
+ serialized_options=None, create_key=None):
977
+ """Arguments are as described in the attribute description above."""
978
+ if create_key is not _internal_create_key:
979
+ _Deprecated('OneofDescriptor')
980
+
981
+ super(OneofDescriptor, self).__init__(
982
+ containing_type.file if containing_type else None,
983
+ options,
984
+ serialized_options,
985
+ 'OneofOptions',
986
+ )
987
+ self.name = name
988
+ self.full_name = full_name
989
+ self.index = index
990
+ self.containing_type = containing_type
991
+ self.fields = fields
992
+
993
+ @property
994
+ def _parent(self):
995
+ return self.containing_type
996
+
997
+
998
+ class ServiceDescriptor(_NestedDescriptorBase):
999
+
1000
+ """Descriptor for a service.
1001
+
1002
+ Attributes:
1003
+ name (str): Name of the service.
1004
+ full_name (str): Full name of the service, including package name.
1005
+ index (int): 0-indexed index giving the order that this services
1006
+ definition appears within the .proto file.
1007
+ methods (list[MethodDescriptor]): List of methods provided by this
1008
+ service.
1009
+ methods_by_name (dict(str, MethodDescriptor)): Same
1010
+ :class:`MethodDescriptor` objects as in :attr:`methods_by_name`, but
1011
+ indexed by "name" attribute in each :class:`MethodDescriptor`.
1012
+ options (descriptor_pb2.ServiceOptions): Service options message or
1013
+ None to use default service options.
1014
+ file (FileDescriptor): Reference to file info.
1015
+ """
1016
+
1017
+ if _USE_C_DESCRIPTORS:
1018
+ _C_DESCRIPTOR_CLASS = _message.ServiceDescriptor
1019
+
1020
+ def __new__(
1021
+ cls,
1022
+ name=None,
1023
+ full_name=None,
1024
+ index=None,
1025
+ methods=None,
1026
+ options=None,
1027
+ serialized_options=None,
1028
+ file=None, # pylint: disable=redefined-builtin
1029
+ serialized_start=None,
1030
+ serialized_end=None,
1031
+ create_key=None):
1032
+ _message.Message._CheckCalledFromGeneratedFile() # pylint: disable=protected-access
1033
+ return _message.default_pool.FindServiceByName(full_name)
1034
+
1035
+ def __init__(self, name, full_name, index, methods, options=None,
1036
+ serialized_options=None, file=None, # pylint: disable=redefined-builtin
1037
+ serialized_start=None, serialized_end=None, create_key=None):
1038
+ if create_key is not _internal_create_key:
1039
+ _Deprecated('ServiceDescriptor')
1040
+
1041
+ super(ServiceDescriptor, self).__init__(
1042
+ options, 'ServiceOptions', name, full_name, file,
1043
+ None, serialized_start=serialized_start,
1044
+ serialized_end=serialized_end, serialized_options=serialized_options)
1045
+ self.index = index
1046
+ self.methods = methods
1047
+ self.methods_by_name = dict((m.name, m) for m in methods)
1048
+ # Set the containing service for each method in this service.
1049
+ for method in self.methods:
1050
+ method.file = self.file
1051
+ method.containing_service = self
1052
+
1053
+ @property
1054
+ def _parent(self):
1055
+ return self.file
1056
+
1057
+ def FindMethodByName(self, name):
1058
+ """Searches for the specified method, and returns its descriptor.
1059
+
1060
+ Args:
1061
+ name (str): Name of the method.
1062
+
1063
+ Returns:
1064
+ MethodDescriptor: The descriptor for the requested method.
1065
+
1066
+ Raises:
1067
+ KeyError: if the method cannot be found in the service.
1068
+ """
1069
+ return self.methods_by_name[name]
1070
+
1071
+ def CopyToProto(self, proto):
1072
+ """Copies this to a descriptor_pb2.ServiceDescriptorProto.
1073
+
1074
+ Args:
1075
+ proto (descriptor_pb2.ServiceDescriptorProto): An empty descriptor proto.
1076
+ """
1077
+ # This function is overridden to give a better doc comment.
1078
+ super(ServiceDescriptor, self).CopyToProto(proto)
1079
+
1080
+
1081
+ class MethodDescriptor(DescriptorBase):
1082
+
1083
+ """Descriptor for a method in a service.
1084
+
1085
+ Attributes:
1086
+ name (str): Name of the method within the service.
1087
+ full_name (str): Full name of method.
1088
+ index (int): 0-indexed index of the method inside the service.
1089
+ containing_service (ServiceDescriptor): The service that contains this
1090
+ method.
1091
+ input_type (Descriptor): The descriptor of the message that this method
1092
+ accepts.
1093
+ output_type (Descriptor): The descriptor of the message that this method
1094
+ returns.
1095
+ client_streaming (bool): Whether this method uses client streaming.
1096
+ server_streaming (bool): Whether this method uses server streaming.
1097
+ options (descriptor_pb2.MethodOptions or None): Method options message, or
1098
+ None to use default method options.
1099
+ """
1100
+
1101
+ if _USE_C_DESCRIPTORS:
1102
+ _C_DESCRIPTOR_CLASS = _message.MethodDescriptor
1103
+
1104
+ def __new__(cls,
1105
+ name,
1106
+ full_name,
1107
+ index,
1108
+ containing_service,
1109
+ input_type,
1110
+ output_type,
1111
+ client_streaming=False,
1112
+ server_streaming=False,
1113
+ options=None,
1114
+ serialized_options=None,
1115
+ create_key=None):
1116
+ _message.Message._CheckCalledFromGeneratedFile() # pylint: disable=protected-access
1117
+ return _message.default_pool.FindMethodByName(full_name)
1118
+
1119
+ def __init__(self,
1120
+ name,
1121
+ full_name,
1122
+ index,
1123
+ containing_service,
1124
+ input_type,
1125
+ output_type,
1126
+ client_streaming=False,
1127
+ server_streaming=False,
1128
+ options=None,
1129
+ serialized_options=None,
1130
+ create_key=None):
1131
+ """The arguments are as described in the description of MethodDescriptor
1132
+ attributes above.
1133
+
1134
+ Note that containing_service may be None, and may be set later if necessary.
1135
+ """
1136
+ if create_key is not _internal_create_key:
1137
+ _Deprecated('MethodDescriptor')
1138
+
1139
+ super(MethodDescriptor, self).__init__(
1140
+ containing_service.file if containing_service else None,
1141
+ options,
1142
+ serialized_options,
1143
+ 'MethodOptions',
1144
+ )
1145
+ self.name = name
1146
+ self.full_name = full_name
1147
+ self.index = index
1148
+ self.containing_service = containing_service
1149
+ self.input_type = input_type
1150
+ self.output_type = output_type
1151
+ self.client_streaming = client_streaming
1152
+ self.server_streaming = server_streaming
1153
+
1154
+ @property
1155
+ def _parent(self):
1156
+ return self.containing_service
1157
+
1158
+ def CopyToProto(self, proto):
1159
+ """Copies this to a descriptor_pb2.MethodDescriptorProto.
1160
+
1161
+ Args:
1162
+ proto (descriptor_pb2.MethodDescriptorProto): An empty descriptor proto.
1163
+
1164
+ Raises:
1165
+ Error: If self couldn't be serialized, due to too few constructor
1166
+ arguments.
1167
+ """
1168
+ if self.containing_service is not None:
1169
+ from google.protobuf import descriptor_pb2
1170
+ service_proto = descriptor_pb2.ServiceDescriptorProto()
1171
+ self.containing_service.CopyToProto(service_proto)
1172
+ proto.CopyFrom(service_proto.method[self.index])
1173
+ else:
1174
+ raise Error('Descriptor does not contain a service.')
1175
+
1176
+
1177
+ class FileDescriptor(DescriptorBase):
1178
+ """Descriptor for a file. Mimics the descriptor_pb2.FileDescriptorProto.
1179
+
1180
+ Note that :attr:`enum_types_by_name`, :attr:`extensions_by_name`, and
1181
+ :attr:`dependencies` fields are only set by the
1182
+ :py:mod:`google.protobuf.message_factory` module, and not by the generated
1183
+ proto code.
1184
+
1185
+ Attributes:
1186
+ name (str): Name of file, relative to root of source tree.
1187
+ package (str): Name of the package
1188
+ edition (Edition): Enum value indicating edition of the file
1189
+ serialized_pb (bytes): Byte string of serialized
1190
+ :class:`descriptor_pb2.FileDescriptorProto`.
1191
+ dependencies (list[FileDescriptor]): List of other :class:`FileDescriptor`
1192
+ objects this :class:`FileDescriptor` depends on.
1193
+ public_dependencies (list[FileDescriptor]): A subset of
1194
+ :attr:`dependencies`, which were declared as "public".
1195
+ message_types_by_name (dict(str, Descriptor)): Mapping from message names to
1196
+ their :class:`Descriptor`.
1197
+ enum_types_by_name (dict(str, EnumDescriptor)): Mapping from enum names to
1198
+ their :class:`EnumDescriptor`.
1199
+ extensions_by_name (dict(str, FieldDescriptor)): Mapping from extension
1200
+ names declared at file scope to their :class:`FieldDescriptor`.
1201
+ services_by_name (dict(str, ServiceDescriptor)): Mapping from services'
1202
+ names to their :class:`ServiceDescriptor`.
1203
+ pool (DescriptorPool): The pool this descriptor belongs to. When not passed
1204
+ to the constructor, the global default pool is used.
1205
+ """
1206
+
1207
+ if _USE_C_DESCRIPTORS:
1208
+ _C_DESCRIPTOR_CLASS = _message.FileDescriptor
1209
+
1210
+ def __new__(
1211
+ cls,
1212
+ name,
1213
+ package,
1214
+ options=None,
1215
+ serialized_options=None,
1216
+ serialized_pb=None,
1217
+ dependencies=None,
1218
+ public_dependencies=None,
1219
+ syntax=None,
1220
+ edition=None,
1221
+ pool=None,
1222
+ create_key=None,
1223
+ ):
1224
+ # FileDescriptor() is called from various places, not only from generated
1225
+ # files, to register dynamic proto files and messages.
1226
+ # pylint: disable=g-explicit-bool-comparison
1227
+ if serialized_pb:
1228
+ return _message.default_pool.AddSerializedFile(serialized_pb)
1229
+ else:
1230
+ return super(FileDescriptor, cls).__new__(cls)
1231
+
1232
+ def __init__(
1233
+ self,
1234
+ name,
1235
+ package,
1236
+ options=None,
1237
+ serialized_options=None,
1238
+ serialized_pb=None,
1239
+ dependencies=None,
1240
+ public_dependencies=None,
1241
+ syntax=None,
1242
+ edition=None,
1243
+ pool=None,
1244
+ create_key=None,
1245
+ ):
1246
+ """Constructor."""
1247
+ if create_key is not _internal_create_key:
1248
+ _Deprecated('FileDescriptor')
1249
+
1250
+ super(FileDescriptor, self).__init__(
1251
+ self, options, serialized_options, 'FileOptions'
1252
+ )
1253
+
1254
+ if edition and edition != 'EDITION_UNKNOWN':
1255
+ self._edition = edition
1256
+ elif syntax == 'proto3':
1257
+ self._edition = 'EDITION_PROTO3'
1258
+ else:
1259
+ self._edition = 'EDITION_PROTO2'
1260
+
1261
+ if pool is None:
1262
+ from google.protobuf import descriptor_pool
1263
+ pool = descriptor_pool.Default()
1264
+ self.pool = pool
1265
+ self.message_types_by_name = {}
1266
+ self.name = name
1267
+ self.package = package
1268
+ self.serialized_pb = serialized_pb
1269
+
1270
+ self.enum_types_by_name = {}
1271
+ self.extensions_by_name = {}
1272
+ self.services_by_name = {}
1273
+ self.dependencies = (dependencies or [])
1274
+ self.public_dependencies = (public_dependencies or [])
1275
+
1276
+ def CopyToProto(self, proto):
1277
+ """Copies this to a descriptor_pb2.FileDescriptorProto.
1278
+
1279
+ Args:
1280
+ proto: An empty descriptor_pb2.FileDescriptorProto.
1281
+ """
1282
+ proto.ParseFromString(self.serialized_pb)
1283
+
1284
+ @property
1285
+ def _parent(self):
1286
+ return None
1287
+
1288
+
1289
+ def _ParseOptions(message, string):
1290
+ """Parses serialized options.
1291
+
1292
+ This helper function is used to parse serialized options in generated
1293
+ proto2 files. It must not be used outside proto2.
1294
+ """
1295
+ message.ParseFromString(string)
1296
+ return message
1297
+
1298
+
1299
+ def _ToCamelCase(name):
1300
+ """Converts name to camel-case and returns it."""
1301
+ capitalize_next = False
1302
+ result = []
1303
+
1304
+ for c in name:
1305
+ if c == '_':
1306
+ if result:
1307
+ capitalize_next = True
1308
+ elif capitalize_next:
1309
+ result.append(c.upper())
1310
+ capitalize_next = False
1311
+ else:
1312
+ result += c
1313
+
1314
+ # Lower-case the first letter.
1315
+ if result and result[0].isupper():
1316
+ result[0] = result[0].lower()
1317
+ return ''.join(result)
1318
+
1319
+
1320
+ def _OptionsOrNone(descriptor_proto):
1321
+ """Returns the value of the field `options`, or None if it is not set."""
1322
+ if descriptor_proto.HasField('options'):
1323
+ return descriptor_proto.options
1324
+ else:
1325
+ return None
1326
+
1327
+
1328
+ def _ToJsonName(name):
1329
+ """Converts name to Json name and returns it."""
1330
+ capitalize_next = False
1331
+ result = []
1332
+
1333
+ for c in name:
1334
+ if c == '_':
1335
+ capitalize_next = True
1336
+ elif capitalize_next:
1337
+ result.append(c.upper())
1338
+ capitalize_next = False
1339
+ else:
1340
+ result += c
1341
+
1342
+ return ''.join(result)
1343
+
1344
+
1345
+ def MakeDescriptor(
1346
+ desc_proto,
1347
+ package='',
1348
+ build_file_if_cpp=True,
1349
+ syntax=None,
1350
+ edition=None,
1351
+ file_desc=None,
1352
+ ):
1353
+ """Make a protobuf Descriptor given a DescriptorProto protobuf.
1354
+
1355
+ Handles nested descriptors. Note that this is limited to the scope of defining
1356
+ a message inside of another message. Composite fields can currently only be
1357
+ resolved if the message is defined in the same scope as the field.
1358
+
1359
+ Args:
1360
+ desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
1361
+ package: Optional package name for the new message Descriptor (string).
1362
+ build_file_if_cpp: Update the C++ descriptor pool if api matches. Set to
1363
+ False on recursion, so no duplicates are created.
1364
+ syntax: The syntax/semantics that should be used. Set to "proto3" to get
1365
+ proto3 field presence semantics.
1366
+ edition: The edition that should be used if syntax is "edition".
1367
+ file_desc: A FileDescriptor to place this descriptor into.
1368
+
1369
+ Returns:
1370
+ A Descriptor for protobuf messages.
1371
+ """
1372
+ # pylint: disable=g-import-not-at-top
1373
+ from google.protobuf import descriptor_pb2
1374
+
1375
+ # Generate a random name for this proto file to prevent conflicts with any
1376
+ # imported ones. We need to specify a file name so the descriptor pool
1377
+ # accepts our FileDescriptorProto, but it is not important what that file
1378
+ # name is actually set to.
1379
+ proto_name = binascii.hexlify(os.urandom(16)).decode('ascii')
1380
+
1381
+ if package:
1382
+ file_name = os.path.join(package.replace('.', '/'), proto_name + '.proto')
1383
+ else:
1384
+ file_name = proto_name + '.proto'
1385
+
1386
+ if api_implementation.Type() != 'python' and build_file_if_cpp:
1387
+ # The C++ implementation requires all descriptors to be backed by the same
1388
+ # definition in the C++ descriptor pool. To do this, we build a
1389
+ # FileDescriptorProto with the same definition as this descriptor and build
1390
+ # it into the pool.
1391
+ file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
1392
+ file_descriptor_proto.message_type.add().MergeFrom(desc_proto)
1393
+
1394
+ if package:
1395
+ file_descriptor_proto.package = package
1396
+ file_descriptor_proto.name = file_name
1397
+
1398
+ _message.default_pool.Add(file_descriptor_proto)
1399
+ result = _message.default_pool.FindFileByName(file_descriptor_proto.name)
1400
+
1401
+ if _USE_C_DESCRIPTORS:
1402
+ return result.message_types_by_name[desc_proto.name]
1403
+
1404
+ if file_desc is None:
1405
+ file_desc = FileDescriptor(
1406
+ pool=None,
1407
+ name=file_name,
1408
+ package=package,
1409
+ syntax=syntax,
1410
+ edition=edition,
1411
+ options=None,
1412
+ serialized_pb='',
1413
+ dependencies=[],
1414
+ public_dependencies=[],
1415
+ create_key=_internal_create_key,
1416
+ )
1417
+ full_message_name = [desc_proto.name]
1418
+ if package: full_message_name.insert(0, package)
1419
+
1420
+ # Create Descriptors for enum types
1421
+ enum_types = {}
1422
+ for enum_proto in desc_proto.enum_type:
1423
+ full_name = '.'.join(full_message_name + [enum_proto.name])
1424
+ enum_desc = EnumDescriptor(
1425
+ enum_proto.name,
1426
+ full_name,
1427
+ None,
1428
+ [
1429
+ EnumValueDescriptor(
1430
+ enum_val.name,
1431
+ ii,
1432
+ enum_val.number,
1433
+ create_key=_internal_create_key,
1434
+ )
1435
+ for ii, enum_val in enumerate(enum_proto.value)
1436
+ ],
1437
+ file=file_desc,
1438
+ create_key=_internal_create_key,
1439
+ )
1440
+ enum_types[full_name] = enum_desc
1441
+
1442
+ # Create Descriptors for nested types
1443
+ nested_types = {}
1444
+ for nested_proto in desc_proto.nested_type:
1445
+ full_name = '.'.join(full_message_name + [nested_proto.name])
1446
+ # Nested types are just those defined inside of the message, not all types
1447
+ # used by fields in the message, so no loops are possible here.
1448
+ nested_desc = MakeDescriptor(
1449
+ nested_proto,
1450
+ package='.'.join(full_message_name),
1451
+ build_file_if_cpp=False,
1452
+ syntax=syntax,
1453
+ edition=edition,
1454
+ file_desc=file_desc,
1455
+ )
1456
+ nested_types[full_name] = nested_desc
1457
+
1458
+ fields = []
1459
+ for field_proto in desc_proto.field:
1460
+ full_name = '.'.join(full_message_name + [field_proto.name])
1461
+ enum_desc = None
1462
+ nested_desc = None
1463
+ if field_proto.json_name:
1464
+ json_name = field_proto.json_name
1465
+ else:
1466
+ json_name = None
1467
+ if field_proto.HasField('type_name'):
1468
+ type_name = field_proto.type_name
1469
+ full_type_name = '.'.join(full_message_name +
1470
+ [type_name[type_name.rfind('.')+1:]])
1471
+ if full_type_name in nested_types:
1472
+ nested_desc = nested_types[full_type_name]
1473
+ elif full_type_name in enum_types:
1474
+ enum_desc = enum_types[full_type_name]
1475
+ # Else type_name references a non-local type, which isn't implemented
1476
+ field = FieldDescriptor(
1477
+ field_proto.name,
1478
+ full_name,
1479
+ field_proto.number - 1,
1480
+ field_proto.number,
1481
+ field_proto.type,
1482
+ FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type),
1483
+ field_proto.label,
1484
+ None,
1485
+ nested_desc,
1486
+ enum_desc,
1487
+ None,
1488
+ False,
1489
+ None,
1490
+ options=_OptionsOrNone(field_proto),
1491
+ has_default_value=False,
1492
+ json_name=json_name,
1493
+ file=file_desc,
1494
+ create_key=_internal_create_key,
1495
+ )
1496
+ fields.append(field)
1497
+
1498
+ desc_name = '.'.join(full_message_name)
1499
+ return Descriptor(
1500
+ desc_proto.name,
1501
+ desc_name,
1502
+ None,
1503
+ None,
1504
+ fields,
1505
+ list(nested_types.values()),
1506
+ list(enum_types.values()),
1507
+ [],
1508
+ options=_OptionsOrNone(desc_proto),
1509
+ file=file_desc,
1510
+ create_key=_internal_create_key,
1511
+ )
.venv/lib/python3.11/site-packages/google/protobuf/descriptor_database.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protocol Buffers - Google's data interchange format
2
+ # Copyright 2008 Google Inc. All rights reserved.
3
+ #
4
+ # Use of this source code is governed by a BSD-style
5
+ # license that can be found in the LICENSE file or at
6
+ # https://developers.google.com/open-source/licenses/bsd
7
+
8
+ """Provides a container for DescriptorProtos."""
9
+
10
+ __author__ = 'matthewtoia@google.com (Matt Toia)'
11
+
12
+ import warnings
13
+
14
+
15
+ class Error(Exception):
16
+ pass
17
+
18
+
19
+ class DescriptorDatabaseConflictingDefinitionError(Error):
20
+ """Raised when a proto is added with the same name & different descriptor."""
21
+
22
+
23
+ class DescriptorDatabase(object):
24
+ """A container accepting FileDescriptorProtos and maps DescriptorProtos."""
25
+
26
+ def __init__(self):
27
+ self._file_desc_protos_by_file = {}
28
+ self._file_desc_protos_by_symbol = {}
29
+
30
+ def Add(self, file_desc_proto):
31
+ """Adds the FileDescriptorProto and its types to this database.
32
+
33
+ Args:
34
+ file_desc_proto: The FileDescriptorProto to add.
35
+ Raises:
36
+ DescriptorDatabaseConflictingDefinitionError: if an attempt is made to
37
+ add a proto with the same name but different definition than an
38
+ existing proto in the database.
39
+ """
40
+ proto_name = file_desc_proto.name
41
+ if proto_name not in self._file_desc_protos_by_file:
42
+ self._file_desc_protos_by_file[proto_name] = file_desc_proto
43
+ elif self._file_desc_protos_by_file[proto_name] != file_desc_proto:
44
+ raise DescriptorDatabaseConflictingDefinitionError(
45
+ '%s already added, but with different descriptor.' % proto_name)
46
+ else:
47
+ return
48
+
49
+ # Add all the top-level descriptors to the index.
50
+ package = file_desc_proto.package
51
+ for message in file_desc_proto.message_type:
52
+ for name in _ExtractSymbols(message, package):
53
+ self._AddSymbol(name, file_desc_proto)
54
+ for enum in file_desc_proto.enum_type:
55
+ self._AddSymbol(('.'.join((package, enum.name))), file_desc_proto)
56
+ for enum_value in enum.value:
57
+ self._file_desc_protos_by_symbol[
58
+ '.'.join((package, enum_value.name))] = file_desc_proto
59
+ for extension in file_desc_proto.extension:
60
+ self._AddSymbol(('.'.join((package, extension.name))), file_desc_proto)
61
+ for service in file_desc_proto.service:
62
+ self._AddSymbol(('.'.join((package, service.name))), file_desc_proto)
63
+
64
+ def FindFileByName(self, name):
65
+ """Finds the file descriptor proto by file name.
66
+
67
+ Typically the file name is a relative path ending to a .proto file. The
68
+ proto with the given name will have to have been added to this database
69
+ using the Add method or else an error will be raised.
70
+
71
+ Args:
72
+ name: The file name to find.
73
+
74
+ Returns:
75
+ The file descriptor proto matching the name.
76
+
77
+ Raises:
78
+ KeyError if no file by the given name was added.
79
+ """
80
+
81
+ return self._file_desc_protos_by_file[name]
82
+
83
+ def FindFileContainingSymbol(self, symbol):
84
+ """Finds the file descriptor proto containing the specified symbol.
85
+
86
+ The symbol should be a fully qualified name including the file descriptor's
87
+ package and any containing messages. Some examples:
88
+
89
+ 'some.package.name.Message'
90
+ 'some.package.name.Message.NestedEnum'
91
+ 'some.package.name.Message.some_field'
92
+
93
+ The file descriptor proto containing the specified symbol must be added to
94
+ this database using the Add method or else an error will be raised.
95
+
96
+ Args:
97
+ symbol: The fully qualified symbol name.
98
+
99
+ Returns:
100
+ The file descriptor proto containing the symbol.
101
+
102
+ Raises:
103
+ KeyError if no file contains the specified symbol.
104
+ """
105
+ try:
106
+ return self._file_desc_protos_by_symbol[symbol]
107
+ except KeyError:
108
+ # Fields, enum values, and nested extensions are not in
109
+ # _file_desc_protos_by_symbol. Try to find the top level
110
+ # descriptor. Non-existent nested symbol under a valid top level
111
+ # descriptor can also be found. The behavior is the same with
112
+ # protobuf C++.
113
+ top_level, _, _ = symbol.rpartition('.')
114
+ try:
115
+ return self._file_desc_protos_by_symbol[top_level]
116
+ except KeyError:
117
+ # Raise the original symbol as a KeyError for better diagnostics.
118
+ raise KeyError(symbol)
119
+
120
+ def FindFileContainingExtension(self, extendee_name, extension_number):
121
+ # TODO: implement this API.
122
+ return None
123
+
124
+ def FindAllExtensionNumbers(self, extendee_name):
125
+ # TODO: implement this API.
126
+ return []
127
+
128
+ def _AddSymbol(self, name, file_desc_proto):
129
+ if name in self._file_desc_protos_by_symbol:
130
+ warn_msg = ('Conflict register for file "' + file_desc_proto.name +
131
+ '": ' + name +
132
+ ' is already defined in file "' +
133
+ self._file_desc_protos_by_symbol[name].name + '"')
134
+ warnings.warn(warn_msg, RuntimeWarning)
135
+ self._file_desc_protos_by_symbol[name] = file_desc_proto
136
+
137
+
138
+ def _ExtractSymbols(desc_proto, package):
139
+ """Pulls out all the symbols from a descriptor proto.
140
+
141
+ Args:
142
+ desc_proto: The proto to extract symbols from.
143
+ package: The package containing the descriptor type.
144
+
145
+ Yields:
146
+ The fully qualified name found in the descriptor.
147
+ """
148
+ message_name = package + '.' + desc_proto.name if package else desc_proto.name
149
+ yield message_name
150
+ for nested_type in desc_proto.nested_type:
151
+ for symbol in _ExtractSymbols(nested_type, message_name):
152
+ yield symbol
153
+ for enum_type in desc_proto.enum_type:
154
+ yield '.'.join((message_name, enum_type.name))
.venv/lib/python3.11/site-packages/google/protobuf/descriptor_pb2.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/google/protobuf/descriptor_pool.py ADDED
@@ -0,0 +1,1355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protocol Buffers - Google's data interchange format
2
+ # Copyright 2008 Google Inc. All rights reserved.
3
+ #
4
+ # Use of this source code is governed by a BSD-style
5
+ # license that can be found in the LICENSE file or at
6
+ # https://developers.google.com/open-source/licenses/bsd
7
+
8
+ """Provides DescriptorPool to use as a container for proto2 descriptors.
9
+
10
+ The DescriptorPool is used in conjection with a DescriptorDatabase to maintain
11
+ a collection of protocol buffer descriptors for use when dynamically creating
12
+ message types at runtime.
13
+
14
+ For most applications protocol buffers should be used via modules generated by
15
+ the protocol buffer compiler tool. This should only be used when the type of
16
+ protocol buffers used in an application or library cannot be predetermined.
17
+
18
+ Below is a straightforward example on how to use this class::
19
+
20
+ pool = DescriptorPool()
21
+ file_descriptor_protos = [ ... ]
22
+ for file_descriptor_proto in file_descriptor_protos:
23
+ pool.Add(file_descriptor_proto)
24
+ my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')
25
+
26
+ The message descriptor can be used in conjunction with the message_factory
27
+ module in order to create a protocol buffer class that can be encoded and
28
+ decoded.
29
+
30
+ If you want to get a Python class for the specified proto, use the
31
+ helper functions inside google.protobuf.message_factory
32
+ directly instead of this class.
33
+ """
34
+
35
+ __author__ = 'matthewtoia@google.com (Matt Toia)'
36
+
37
+ import collections
38
+ import threading
39
+ import warnings
40
+
41
+ from google.protobuf import descriptor
42
+ from google.protobuf import descriptor_database
43
+ from google.protobuf import text_encoding
44
+ from google.protobuf.internal import python_edition_defaults
45
+ from google.protobuf.internal import python_message
46
+
47
+ _USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access
48
+
49
+
50
+ def _NormalizeFullyQualifiedName(name):
51
+ """Remove leading period from fully-qualified type name.
52
+
53
+ Due to b/13860351 in descriptor_database.py, types in the root namespace are
54
+ generated with a leading period. This function removes that prefix.
55
+
56
+ Args:
57
+ name (str): The fully-qualified symbol name.
58
+
59
+ Returns:
60
+ str: The normalized fully-qualified symbol name.
61
+ """
62
+ return name.lstrip('.')
63
+
64
+
65
+ def _OptionsOrNone(descriptor_proto):
66
+ """Returns the value of the field `options`, or None if it is not set."""
67
+ if descriptor_proto.HasField('options'):
68
+ return descriptor_proto.options
69
+ else:
70
+ return None
71
+
72
+
73
+ def _IsMessageSetExtension(field):
74
+ return (field.is_extension and
75
+ field.containing_type.has_options and
76
+ field.containing_type.GetOptions().message_set_wire_format and
77
+ field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
78
+ field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL)
79
+
80
+ _edition_defaults_lock = threading.Lock()
81
+
82
+
83
+ class DescriptorPool(object):
84
+ """A collection of protobufs dynamically constructed by descriptor protos."""
85
+
86
+ if _USE_C_DESCRIPTORS:
87
+
88
+ def __new__(cls, descriptor_db=None):
89
+ # pylint: disable=protected-access
90
+ return descriptor._message.DescriptorPool(descriptor_db)
91
+
92
+ def __init__(
93
+ self, descriptor_db=None, use_deprecated_legacy_json_field_conflicts=False
94
+ ):
95
+ """Initializes a Pool of proto buffs.
96
+
97
+ The descriptor_db argument to the constructor is provided to allow
98
+ specialized file descriptor proto lookup code to be triggered on demand. An
99
+ example would be an implementation which will read and compile a file
100
+ specified in a call to FindFileByName() and not require the call to Add()
101
+ at all. Results from this database will be cached internally here as well.
102
+
103
+ Args:
104
+ descriptor_db: A secondary source of file descriptors.
105
+ use_deprecated_legacy_json_field_conflicts: Unused, for compatibility with
106
+ C++.
107
+ """
108
+
109
+ self._internal_db = descriptor_database.DescriptorDatabase()
110
+ self._descriptor_db = descriptor_db
111
+ self._descriptors = {}
112
+ self._enum_descriptors = {}
113
+ self._service_descriptors = {}
114
+ self._file_descriptors = {}
115
+ self._toplevel_extensions = {}
116
+ self._top_enum_values = {}
117
+ # We store extensions in two two-level mappings: The first key is the
118
+ # descriptor of the message being extended, the second key is the extension
119
+ # full name or its tag number.
120
+ self._extensions_by_name = collections.defaultdict(dict)
121
+ self._extensions_by_number = collections.defaultdict(dict)
122
+ self._serialized_edition_defaults = (
123
+ python_edition_defaults._PROTOBUF_INTERNAL_PYTHON_EDITION_DEFAULTS
124
+ )
125
+ self._edition_defaults = None
126
+ self._feature_cache = dict()
127
+
128
+ def _CheckConflictRegister(self, desc, desc_name, file_name):
129
+ """Check if the descriptor name conflicts with another of the same name.
130
+
131
+ Args:
132
+ desc: Descriptor of a message, enum, service, extension or enum value.
133
+ desc_name (str): the full name of desc.
134
+ file_name (str): The file name of descriptor.
135
+ """
136
+ for register, descriptor_type in [
137
+ (self._descriptors, descriptor.Descriptor),
138
+ (self._enum_descriptors, descriptor.EnumDescriptor),
139
+ (self._service_descriptors, descriptor.ServiceDescriptor),
140
+ (self._toplevel_extensions, descriptor.FieldDescriptor),
141
+ (self._top_enum_values, descriptor.EnumValueDescriptor)]:
142
+ if desc_name in register:
143
+ old_desc = register[desc_name]
144
+ if isinstance(old_desc, descriptor.EnumValueDescriptor):
145
+ old_file = old_desc.type.file.name
146
+ else:
147
+ old_file = old_desc.file.name
148
+
149
+ if not isinstance(desc, descriptor_type) or (
150
+ old_file != file_name):
151
+ error_msg = ('Conflict register for file "' + file_name +
152
+ '": ' + desc_name +
153
+ ' is already defined in file "' +
154
+ old_file + '". Please fix the conflict by adding '
155
+ 'package name on the proto file, or use different '
156
+ 'name for the duplication.')
157
+ if isinstance(desc, descriptor.EnumValueDescriptor):
158
+ error_msg += ('\nNote: enum values appear as '
159
+ 'siblings of the enum type instead of '
160
+ 'children of it.')
161
+
162
+ raise TypeError(error_msg)
163
+
164
+ return
165
+
166
+ def Add(self, file_desc_proto):
167
+ """Adds the FileDescriptorProto and its types to this pool.
168
+
169
+ Args:
170
+ file_desc_proto (FileDescriptorProto): The file descriptor to add.
171
+ """
172
+
173
+ self._internal_db.Add(file_desc_proto)
174
+
175
+ def AddSerializedFile(self, serialized_file_desc_proto):
176
+ """Adds the FileDescriptorProto and its types to this pool.
177
+
178
+ Args:
179
+ serialized_file_desc_proto (bytes): A bytes string, serialization of the
180
+ :class:`FileDescriptorProto` to add.
181
+
182
+ Returns:
183
+ FileDescriptor: Descriptor for the added file.
184
+ """
185
+
186
+ # pylint: disable=g-import-not-at-top
187
+ from google.protobuf import descriptor_pb2
188
+ file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
189
+ serialized_file_desc_proto)
190
+ file_desc = self._ConvertFileProtoToFileDescriptor(file_desc_proto)
191
+ file_desc.serialized_pb = serialized_file_desc_proto
192
+ return file_desc
193
+
194
+ # Never call this method. It is for internal usage only.
195
+ def _AddDescriptor(self, desc):
196
+ """Adds a Descriptor to the pool, non-recursively.
197
+
198
+ If the Descriptor contains nested messages or enums, the caller must
199
+ explicitly register them. This method also registers the FileDescriptor
200
+ associated with the message.
201
+
202
+ Args:
203
+ desc: A Descriptor.
204
+ """
205
+ if not isinstance(desc, descriptor.Descriptor):
206
+ raise TypeError('Expected instance of descriptor.Descriptor.')
207
+
208
+ self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
209
+
210
+ self._descriptors[desc.full_name] = desc
211
+ self._AddFileDescriptor(desc.file)
212
+
213
+ # Never call this method. It is for internal usage only.
214
+ def _AddEnumDescriptor(self, enum_desc):
215
+ """Adds an EnumDescriptor to the pool.
216
+
217
+ This method also registers the FileDescriptor associated with the enum.
218
+
219
+ Args:
220
+ enum_desc: An EnumDescriptor.
221
+ """
222
+
223
+ if not isinstance(enum_desc, descriptor.EnumDescriptor):
224
+ raise TypeError('Expected instance of descriptor.EnumDescriptor.')
225
+
226
+ file_name = enum_desc.file.name
227
+ self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
228
+ self._enum_descriptors[enum_desc.full_name] = enum_desc
229
+
230
+ # Top enum values need to be indexed.
231
+ # Count the number of dots to see whether the enum is toplevel or nested
232
+ # in a message. We cannot use enum_desc.containing_type at this stage.
233
+ if enum_desc.file.package:
234
+ top_level = (enum_desc.full_name.count('.')
235
+ - enum_desc.file.package.count('.') == 1)
236
+ else:
237
+ top_level = enum_desc.full_name.count('.') == 0
238
+ if top_level:
239
+ file_name = enum_desc.file.name
240
+ package = enum_desc.file.package
241
+ for enum_value in enum_desc.values:
242
+ full_name = _NormalizeFullyQualifiedName(
243
+ '.'.join((package, enum_value.name)))
244
+ self._CheckConflictRegister(enum_value, full_name, file_name)
245
+ self._top_enum_values[full_name] = enum_value
246
+ self._AddFileDescriptor(enum_desc.file)
247
+
248
+ # Never call this method. It is for internal usage only.
249
+ def _AddServiceDescriptor(self, service_desc):
250
+ """Adds a ServiceDescriptor to the pool.
251
+
252
+ Args:
253
+ service_desc: A ServiceDescriptor.
254
+ """
255
+
256
+ if not isinstance(service_desc, descriptor.ServiceDescriptor):
257
+ raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
258
+
259
+ self._CheckConflictRegister(service_desc, service_desc.full_name,
260
+ service_desc.file.name)
261
+ self._service_descriptors[service_desc.full_name] = service_desc
262
+
263
+ # Never call this method. It is for internal usage only.
264
+ def _AddExtensionDescriptor(self, extension):
265
+ """Adds a FieldDescriptor describing an extension to the pool.
266
+
267
+ Args:
268
+ extension: A FieldDescriptor.
269
+
270
+ Raises:
271
+ AssertionError: when another extension with the same number extends the
272
+ same message.
273
+ TypeError: when the specified extension is not a
274
+ descriptor.FieldDescriptor.
275
+ """
276
+ if not (isinstance(extension, descriptor.FieldDescriptor) and
277
+ extension.is_extension):
278
+ raise TypeError('Expected an extension descriptor.')
279
+
280
+ if extension.extension_scope is None:
281
+ self._CheckConflictRegister(
282
+ extension, extension.full_name, extension.file.name)
283
+ self._toplevel_extensions[extension.full_name] = extension
284
+
285
+ try:
286
+ existing_desc = self._extensions_by_number[
287
+ extension.containing_type][extension.number]
288
+ except KeyError:
289
+ pass
290
+ else:
291
+ if extension is not existing_desc:
292
+ raise AssertionError(
293
+ 'Extensions "%s" and "%s" both try to extend message type "%s" '
294
+ 'with field number %d.' %
295
+ (extension.full_name, existing_desc.full_name,
296
+ extension.containing_type.full_name, extension.number))
297
+
298
+ self._extensions_by_number[extension.containing_type][
299
+ extension.number] = extension
300
+ self._extensions_by_name[extension.containing_type][
301
+ extension.full_name] = extension
302
+
303
+ # Also register MessageSet extensions with the type name.
304
+ if _IsMessageSetExtension(extension):
305
+ self._extensions_by_name[extension.containing_type][
306
+ extension.message_type.full_name] = extension
307
+
308
+ if hasattr(extension.containing_type, '_concrete_class'):
309
+ python_message._AttachFieldHelpers(
310
+ extension.containing_type._concrete_class, extension)
311
+
312
+ # Never call this method. It is for internal usage only.
313
+ def _InternalAddFileDescriptor(self, file_desc):
314
+ """Adds a FileDescriptor to the pool, non-recursively.
315
+
316
+ If the FileDescriptor contains messages or enums, the caller must explicitly
317
+ register them.
318
+
319
+ Args:
320
+ file_desc: A FileDescriptor.
321
+ """
322
+
323
+ self._AddFileDescriptor(file_desc)
324
+
325
+ def _AddFileDescriptor(self, file_desc):
326
+ """Adds a FileDescriptor to the pool, non-recursively.
327
+
328
+ If the FileDescriptor contains messages or enums, the caller must explicitly
329
+ register them.
330
+
331
+ Args:
332
+ file_desc: A FileDescriptor.
333
+ """
334
+
335
+ if not isinstance(file_desc, descriptor.FileDescriptor):
336
+ raise TypeError('Expected instance of descriptor.FileDescriptor.')
337
+ self._file_descriptors[file_desc.name] = file_desc
338
+
339
+ def FindFileByName(self, file_name):
340
+ """Gets a FileDescriptor by file name.
341
+
342
+ Args:
343
+ file_name (str): The path to the file to get a descriptor for.
344
+
345
+ Returns:
346
+ FileDescriptor: The descriptor for the named file.
347
+
348
+ Raises:
349
+ KeyError: if the file cannot be found in the pool.
350
+ """
351
+
352
+ try:
353
+ return self._file_descriptors[file_name]
354
+ except KeyError:
355
+ pass
356
+
357
+ try:
358
+ file_proto = self._internal_db.FindFileByName(file_name)
359
+ except KeyError as error:
360
+ if self._descriptor_db:
361
+ file_proto = self._descriptor_db.FindFileByName(file_name)
362
+ else:
363
+ raise error
364
+ if not file_proto:
365
+ raise KeyError('Cannot find a file named %s' % file_name)
366
+ return self._ConvertFileProtoToFileDescriptor(file_proto)
367
+
368
+ def FindFileContainingSymbol(self, symbol):
369
+ """Gets the FileDescriptor for the file containing the specified symbol.
370
+
371
+ Args:
372
+ symbol (str): The name of the symbol to search for.
373
+
374
+ Returns:
375
+ FileDescriptor: Descriptor for the file that contains the specified
376
+ symbol.
377
+
378
+ Raises:
379
+ KeyError: if the file cannot be found in the pool.
380
+ """
381
+
382
+ symbol = _NormalizeFullyQualifiedName(symbol)
383
+ try:
384
+ return self._InternalFindFileContainingSymbol(symbol)
385
+ except KeyError:
386
+ pass
387
+
388
+ try:
389
+ # Try fallback database. Build and find again if possible.
390
+ self._FindFileContainingSymbolInDb(symbol)
391
+ return self._InternalFindFileContainingSymbol(symbol)
392
+ except KeyError:
393
+ raise KeyError('Cannot find a file containing %s' % symbol)
394
+
395
+ def _InternalFindFileContainingSymbol(self, symbol):
396
+ """Gets the already built FileDescriptor containing the specified symbol.
397
+
398
+ Args:
399
+ symbol (str): The name of the symbol to search for.
400
+
401
+ Returns:
402
+ FileDescriptor: Descriptor for the file that contains the specified
403
+ symbol.
404
+
405
+ Raises:
406
+ KeyError: if the file cannot be found in the pool.
407
+ """
408
+ try:
409
+ return self._descriptors[symbol].file
410
+ except KeyError:
411
+ pass
412
+
413
+ try:
414
+ return self._enum_descriptors[symbol].file
415
+ except KeyError:
416
+ pass
417
+
418
+ try:
419
+ return self._service_descriptors[symbol].file
420
+ except KeyError:
421
+ pass
422
+
423
+ try:
424
+ return self._top_enum_values[symbol].type.file
425
+ except KeyError:
426
+ pass
427
+
428
+ try:
429
+ return self._toplevel_extensions[symbol].file
430
+ except KeyError:
431
+ pass
432
+
433
+ # Try fields, enum values and nested extensions inside a message.
434
+ top_name, _, sub_name = symbol.rpartition('.')
435
+ try:
436
+ message = self.FindMessageTypeByName(top_name)
437
+ assert (sub_name in message.extensions_by_name or
438
+ sub_name in message.fields_by_name or
439
+ sub_name in message.enum_values_by_name)
440
+ return message.file
441
+ except (KeyError, AssertionError):
442
+ raise KeyError('Cannot find a file containing %s' % symbol)
443
+
444
+ def FindMessageTypeByName(self, full_name):
445
+ """Loads the named descriptor from the pool.
446
+
447
+ Args:
448
+ full_name (str): The full name of the descriptor to load.
449
+
450
+ Returns:
451
+ Descriptor: The descriptor for the named type.
452
+
453
+ Raises:
454
+ KeyError: if the message cannot be found in the pool.
455
+ """
456
+
457
+ full_name = _NormalizeFullyQualifiedName(full_name)
458
+ if full_name not in self._descriptors:
459
+ self._FindFileContainingSymbolInDb(full_name)
460
+ return self._descriptors[full_name]
461
+
462
+ def FindEnumTypeByName(self, full_name):
463
+ """Loads the named enum descriptor from the pool.
464
+
465
+ Args:
466
+ full_name (str): The full name of the enum descriptor to load.
467
+
468
+ Returns:
469
+ EnumDescriptor: The enum descriptor for the named type.
470
+
471
+ Raises:
472
+ KeyError: if the enum cannot be found in the pool.
473
+ """
474
+
475
+ full_name = _NormalizeFullyQualifiedName(full_name)
476
+ if full_name not in self._enum_descriptors:
477
+ self._FindFileContainingSymbolInDb(full_name)
478
+ return self._enum_descriptors[full_name]
479
+
480
+ def FindFieldByName(self, full_name):
481
+ """Loads the named field descriptor from the pool.
482
+
483
+ Args:
484
+ full_name (str): The full name of the field descriptor to load.
485
+
486
+ Returns:
487
+ FieldDescriptor: The field descriptor for the named field.
488
+
489
+ Raises:
490
+ KeyError: if the field cannot be found in the pool.
491
+ """
492
+ full_name = _NormalizeFullyQualifiedName(full_name)
493
+ message_name, _, field_name = full_name.rpartition('.')
494
+ message_descriptor = self.FindMessageTypeByName(message_name)
495
+ return message_descriptor.fields_by_name[field_name]
496
+
497
+ def FindOneofByName(self, full_name):
498
+ """Loads the named oneof descriptor from the pool.
499
+
500
+ Args:
501
+ full_name (str): The full name of the oneof descriptor to load.
502
+
503
+ Returns:
504
+ OneofDescriptor: The oneof descriptor for the named oneof.
505
+
506
+ Raises:
507
+ KeyError: if the oneof cannot be found in the pool.
508
+ """
509
+ full_name = _NormalizeFullyQualifiedName(full_name)
510
+ message_name, _, oneof_name = full_name.rpartition('.')
511
+ message_descriptor = self.FindMessageTypeByName(message_name)
512
+ return message_descriptor.oneofs_by_name[oneof_name]
513
+
514
+ def FindExtensionByName(self, full_name):
515
+ """Loads the named extension descriptor from the pool.
516
+
517
+ Args:
518
+ full_name (str): The full name of the extension descriptor to load.
519
+
520
+ Returns:
521
+ FieldDescriptor: The field descriptor for the named extension.
522
+
523
+ Raises:
524
+ KeyError: if the extension cannot be found in the pool.
525
+ """
526
+ full_name = _NormalizeFullyQualifiedName(full_name)
527
+ try:
528
+ # The proto compiler does not give any link between the FileDescriptor
529
+ # and top-level extensions unless the FileDescriptorProto is added to
530
+ # the DescriptorDatabase, but this can impact memory usage.
531
+ # So we registered these extensions by name explicitly.
532
+ return self._toplevel_extensions[full_name]
533
+ except KeyError:
534
+ pass
535
+ message_name, _, extension_name = full_name.rpartition('.')
536
+ try:
537
+ # Most extensions are nested inside a message.
538
+ scope = self.FindMessageTypeByName(message_name)
539
+ except KeyError:
540
+ # Some extensions are defined at file scope.
541
+ scope = self._FindFileContainingSymbolInDb(full_name)
542
+ return scope.extensions_by_name[extension_name]
543
+
544
+ def FindExtensionByNumber(self, message_descriptor, number):
545
+ """Gets the extension of the specified message with the specified number.
546
+
547
+ Extensions have to be registered to this pool by calling :func:`Add` or
548
+ :func:`AddExtensionDescriptor`.
549
+
550
+ Args:
551
+ message_descriptor (Descriptor): descriptor of the extended message.
552
+ number (int): Number of the extension field.
553
+
554
+ Returns:
555
+ FieldDescriptor: The descriptor for the extension.
556
+
557
+ Raises:
558
+ KeyError: when no extension with the given number is known for the
559
+ specified message.
560
+ """
561
+ try:
562
+ return self._extensions_by_number[message_descriptor][number]
563
+ except KeyError:
564
+ self._TryLoadExtensionFromDB(message_descriptor, number)
565
+ return self._extensions_by_number[message_descriptor][number]
566
+
567
+ def FindAllExtensions(self, message_descriptor):
568
+ """Gets all the known extensions of a given message.
569
+
570
+ Extensions have to be registered to this pool by build related
571
+ :func:`Add` or :func:`AddExtensionDescriptor`.
572
+
573
+ Args:
574
+ message_descriptor (Descriptor): Descriptor of the extended message.
575
+
576
+ Returns:
577
+ list[FieldDescriptor]: Field descriptors describing the extensions.
578
+ """
579
+ # Fallback to descriptor db if FindAllExtensionNumbers is provided.
580
+ if self._descriptor_db and hasattr(
581
+ self._descriptor_db, 'FindAllExtensionNumbers'):
582
+ full_name = message_descriptor.full_name
583
+ all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
584
+ for number in all_numbers:
585
+ if number in self._extensions_by_number[message_descriptor]:
586
+ continue
587
+ self._TryLoadExtensionFromDB(message_descriptor, number)
588
+
589
+ return list(self._extensions_by_number[message_descriptor].values())
590
+
591
+ def _TryLoadExtensionFromDB(self, message_descriptor, number):
592
+ """Try to Load extensions from descriptor db.
593
+
594
+ Args:
595
+ message_descriptor: descriptor of the extended message.
596
+ number: the extension number that needs to be loaded.
597
+ """
598
+ if not self._descriptor_db:
599
+ return
600
+ # Only supported when FindFileContainingExtension is provided.
601
+ if not hasattr(
602
+ self._descriptor_db, 'FindFileContainingExtension'):
603
+ return
604
+
605
+ full_name = message_descriptor.full_name
606
+ file_proto = self._descriptor_db.FindFileContainingExtension(
607
+ full_name, number)
608
+
609
+ if file_proto is None:
610
+ return
611
+
612
+ try:
613
+ self._ConvertFileProtoToFileDescriptor(file_proto)
614
+ except:
615
+ warn_msg = ('Unable to load proto file %s for extension number %d.' %
616
+ (file_proto.name, number))
617
+ warnings.warn(warn_msg, RuntimeWarning)
618
+
619
+ def FindServiceByName(self, full_name):
620
+ """Loads the named service descriptor from the pool.
621
+
622
+ Args:
623
+ full_name (str): The full name of the service descriptor to load.
624
+
625
+ Returns:
626
+ ServiceDescriptor: The service descriptor for the named service.
627
+
628
+ Raises:
629
+ KeyError: if the service cannot be found in the pool.
630
+ """
631
+ full_name = _NormalizeFullyQualifiedName(full_name)
632
+ if full_name not in self._service_descriptors:
633
+ self._FindFileContainingSymbolInDb(full_name)
634
+ return self._service_descriptors[full_name]
635
+
636
+ def FindMethodByName(self, full_name):
637
+ """Loads the named service method descriptor from the pool.
638
+
639
+ Args:
640
+ full_name (str): The full name of the method descriptor to load.
641
+
642
+ Returns:
643
+ MethodDescriptor: The method descriptor for the service method.
644
+
645
+ Raises:
646
+ KeyError: if the method cannot be found in the pool.
647
+ """
648
+ full_name = _NormalizeFullyQualifiedName(full_name)
649
+ service_name, _, method_name = full_name.rpartition('.')
650
+ service_descriptor = self.FindServiceByName(service_name)
651
+ return service_descriptor.methods_by_name[method_name]
652
+
653
+ def SetFeatureSetDefaults(self, defaults):
654
+ """Sets the default feature mappings used during the build.
655
+
656
+ Args:
657
+ defaults: a FeatureSetDefaults message containing the new mappings.
658
+ """
659
+ if self._edition_defaults is not None:
660
+ raise ValueError(
661
+ "Feature set defaults can't be changed once the pool has started"
662
+ ' building!'
663
+ )
664
+
665
+ # pylint: disable=g-import-not-at-top
666
+ from google.protobuf import descriptor_pb2
667
+
668
+ if not isinstance(defaults, descriptor_pb2.FeatureSetDefaults):
669
+ raise TypeError('SetFeatureSetDefaults called with invalid type')
670
+
671
+
672
+ if defaults.minimum_edition > defaults.maximum_edition:
673
+ raise ValueError(
674
+ 'Invalid edition range %s to %s'
675
+ % (
676
+ descriptor_pb2.Edition.Name(defaults.minimum_edition),
677
+ descriptor_pb2.Edition.Name(defaults.maximum_edition),
678
+ )
679
+ )
680
+
681
+ prev_edition = descriptor_pb2.Edition.EDITION_UNKNOWN
682
+ for d in defaults.defaults:
683
+ if d.edition == descriptor_pb2.Edition.EDITION_UNKNOWN:
684
+ raise ValueError('Invalid edition EDITION_UNKNOWN specified')
685
+ if prev_edition >= d.edition:
686
+ raise ValueError(
687
+ 'Feature set defaults are not strictly increasing. %s is greater'
688
+ ' than or equal to %s'
689
+ % (
690
+ descriptor_pb2.Edition.Name(prev_edition),
691
+ descriptor_pb2.Edition.Name(d.edition),
692
+ )
693
+ )
694
+ prev_edition = d.edition
695
+ self._edition_defaults = defaults
696
+
697
+ def _CreateDefaultFeatures(self, edition):
698
+ """Creates a FeatureSet message with defaults for a specific edition.
699
+
700
+ Args:
701
+ edition: the edition to generate defaults for.
702
+
703
+ Returns:
704
+ A FeatureSet message with defaults for a specific edition.
705
+ """
706
+ # pylint: disable=g-import-not-at-top
707
+ from google.protobuf import descriptor_pb2
708
+
709
+ with _edition_defaults_lock:
710
+ if not self._edition_defaults:
711
+ self._edition_defaults = descriptor_pb2.FeatureSetDefaults()
712
+ self._edition_defaults.ParseFromString(
713
+ self._serialized_edition_defaults
714
+ )
715
+
716
+ if edition < self._edition_defaults.minimum_edition:
717
+ raise TypeError(
718
+ 'Edition %s is earlier than the minimum supported edition %s!'
719
+ % (
720
+ descriptor_pb2.Edition.Name(edition),
721
+ descriptor_pb2.Edition.Name(
722
+ self._edition_defaults.minimum_edition
723
+ ),
724
+ )
725
+ )
726
+ if edition > self._edition_defaults.maximum_edition:
727
+ raise TypeError(
728
+ 'Edition %s is later than the maximum supported edition %s!'
729
+ % (
730
+ descriptor_pb2.Edition.Name(edition),
731
+ descriptor_pb2.Edition.Name(
732
+ self._edition_defaults.maximum_edition
733
+ ),
734
+ )
735
+ )
736
+ found = None
737
+ for d in self._edition_defaults.defaults:
738
+ if d.edition > edition:
739
+ break
740
+ found = d
741
+ if found is None:
742
+ raise TypeError(
743
+ 'No valid default found for edition %s!'
744
+ % descriptor_pb2.Edition.Name(edition)
745
+ )
746
+
747
+ defaults = descriptor_pb2.FeatureSet()
748
+ defaults.CopyFrom(found.fixed_features)
749
+ defaults.MergeFrom(found.overridable_features)
750
+ return defaults
751
+
752
+ def _InternFeatures(self, features):
753
+ serialized = features.SerializeToString()
754
+ with _edition_defaults_lock:
755
+ cached = self._feature_cache.get(serialized)
756
+ if cached is None:
757
+ self._feature_cache[serialized] = features
758
+ cached = features
759
+ return cached
760
+
761
+ def _FindFileContainingSymbolInDb(self, symbol):
762
+ """Finds the file in descriptor DB containing the specified symbol.
763
+
764
+ Args:
765
+ symbol (str): The name of the symbol to search for.
766
+
767
+ Returns:
768
+ FileDescriptor: The file that contains the specified symbol.
769
+
770
+ Raises:
771
+ KeyError: if the file cannot be found in the descriptor database.
772
+ """
773
+ try:
774
+ file_proto = self._internal_db.FindFileContainingSymbol(symbol)
775
+ except KeyError as error:
776
+ if self._descriptor_db:
777
+ file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
778
+ else:
779
+ raise error
780
+ if not file_proto:
781
+ raise KeyError('Cannot find a file containing %s' % symbol)
782
+ return self._ConvertFileProtoToFileDescriptor(file_proto)
783
+
784
+ def _ConvertFileProtoToFileDescriptor(self, file_proto):
785
+ """Creates a FileDescriptor from a proto or returns a cached copy.
786
+
787
+ This method also has the side effect of loading all the symbols found in
788
+ the file into the appropriate dictionaries in the pool.
789
+
790
+ Args:
791
+ file_proto: The proto to convert.
792
+
793
+ Returns:
794
+ A FileDescriptor matching the passed in proto.
795
+ """
796
+ if file_proto.name not in self._file_descriptors:
797
+ built_deps = list(self._GetDeps(file_proto.dependency))
798
+ direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
799
+ public_deps = [direct_deps[i] for i in file_proto.public_dependency]
800
+
801
+ # pylint: disable=g-import-not-at-top
802
+ from google.protobuf import descriptor_pb2
803
+
804
+ file_descriptor = descriptor.FileDescriptor(
805
+ pool=self,
806
+ name=file_proto.name,
807
+ package=file_proto.package,
808
+ syntax=file_proto.syntax,
809
+ edition=descriptor_pb2.Edition.Name(file_proto.edition),
810
+ options=_OptionsOrNone(file_proto),
811
+ serialized_pb=file_proto.SerializeToString(),
812
+ dependencies=direct_deps,
813
+ public_dependencies=public_deps,
814
+ # pylint: disable=protected-access
815
+ create_key=descriptor._internal_create_key,
816
+ )
817
+ scope = {}
818
+
819
+ # This loop extracts all the message and enum types from all the
820
+ # dependencies of the file_proto. This is necessary to create the
821
+ # scope of available message types when defining the passed in
822
+ # file proto.
823
+ for dependency in built_deps:
824
+ scope.update(self._ExtractSymbols(
825
+ dependency.message_types_by_name.values()))
826
+ scope.update((_PrefixWithDot(enum.full_name), enum)
827
+ for enum in dependency.enum_types_by_name.values())
828
+
829
+ for message_type in file_proto.message_type:
830
+ message_desc = self._ConvertMessageDescriptor(
831
+ message_type, file_proto.package, file_descriptor, scope,
832
+ file_proto.syntax)
833
+ file_descriptor.message_types_by_name[message_desc.name] = (
834
+ message_desc)
835
+
836
+ for enum_type in file_proto.enum_type:
837
+ file_descriptor.enum_types_by_name[enum_type.name] = (
838
+ self._ConvertEnumDescriptor(enum_type, file_proto.package,
839
+ file_descriptor, None, scope, True))
840
+
841
+ for index, extension_proto in enumerate(file_proto.extension):
842
+ extension_desc = self._MakeFieldDescriptor(
843
+ extension_proto, file_proto.package, index, file_descriptor,
844
+ is_extension=True)
845
+ extension_desc.containing_type = self._GetTypeFromScope(
846
+ file_descriptor.package, extension_proto.extendee, scope)
847
+ self._SetFieldType(extension_proto, extension_desc,
848
+ file_descriptor.package, scope)
849
+ file_descriptor.extensions_by_name[extension_desc.name] = (
850
+ extension_desc)
851
+
852
+ for desc_proto in file_proto.message_type:
853
+ self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
854
+
855
+ if file_proto.package:
856
+ desc_proto_prefix = _PrefixWithDot(file_proto.package)
857
+ else:
858
+ desc_proto_prefix = ''
859
+
860
+ for desc_proto in file_proto.message_type:
861
+ desc = self._GetTypeFromScope(
862
+ desc_proto_prefix, desc_proto.name, scope)
863
+ file_descriptor.message_types_by_name[desc_proto.name] = desc
864
+
865
+ for index, service_proto in enumerate(file_proto.service):
866
+ file_descriptor.services_by_name[service_proto.name] = (
867
+ self._MakeServiceDescriptor(service_proto, index, scope,
868
+ file_proto.package, file_descriptor))
869
+
870
+ self._file_descriptors[file_proto.name] = file_descriptor
871
+
872
+ # Add extensions to the pool
873
+ def AddExtensionForNested(message_type):
874
+ for nested in message_type.nested_types:
875
+ AddExtensionForNested(nested)
876
+ for extension in message_type.extensions:
877
+ self._AddExtensionDescriptor(extension)
878
+
879
+ file_desc = self._file_descriptors[file_proto.name]
880
+ for extension in file_desc.extensions_by_name.values():
881
+ self._AddExtensionDescriptor(extension)
882
+ for message_type in file_desc.message_types_by_name.values():
883
+ AddExtensionForNested(message_type)
884
+
885
+ return file_desc
886
+
887
+ def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
888
+ scope=None, syntax=None):
889
+ """Adds the proto to the pool in the specified package.
890
+
891
+ Args:
892
+ desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
893
+ package: The package the proto should be located in.
894
+ file_desc: The file containing this message.
895
+ scope: Dict mapping short and full symbols to message and enum types.
896
+ syntax: string indicating syntax of the file ("proto2" or "proto3")
897
+
898
+ Returns:
899
+ The added descriptor.
900
+ """
901
+
902
+ if package:
903
+ desc_name = '.'.join((package, desc_proto.name))
904
+ else:
905
+ desc_name = desc_proto.name
906
+
907
+ if file_desc is None:
908
+ file_name = None
909
+ else:
910
+ file_name = file_desc.name
911
+
912
+ if scope is None:
913
+ scope = {}
914
+
915
+ nested = [
916
+ self._ConvertMessageDescriptor(
917
+ nested, desc_name, file_desc, scope, syntax)
918
+ for nested in desc_proto.nested_type]
919
+ enums = [
920
+ self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
921
+ scope, False)
922
+ for enum in desc_proto.enum_type]
923
+ fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
924
+ for index, field in enumerate(desc_proto.field)]
925
+ extensions = [
926
+ self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
927
+ is_extension=True)
928
+ for index, extension in enumerate(desc_proto.extension)]
929
+ oneofs = [
930
+ # pylint: disable=g-complex-comprehension
931
+ descriptor.OneofDescriptor(
932
+ desc.name,
933
+ '.'.join((desc_name, desc.name)),
934
+ index,
935
+ None,
936
+ [],
937
+ _OptionsOrNone(desc),
938
+ # pylint: disable=protected-access
939
+ create_key=descriptor._internal_create_key)
940
+ for index, desc in enumerate(desc_proto.oneof_decl)
941
+ ]
942
+ extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
943
+ if extension_ranges:
944
+ is_extendable = True
945
+ else:
946
+ is_extendable = False
947
+ desc = descriptor.Descriptor(
948
+ name=desc_proto.name,
949
+ full_name=desc_name,
950
+ filename=file_name,
951
+ containing_type=None,
952
+ fields=fields,
953
+ oneofs=oneofs,
954
+ nested_types=nested,
955
+ enum_types=enums,
956
+ extensions=extensions,
957
+ options=_OptionsOrNone(desc_proto),
958
+ is_extendable=is_extendable,
959
+ extension_ranges=extension_ranges,
960
+ file=file_desc,
961
+ serialized_start=None,
962
+ serialized_end=None,
963
+ is_map_entry=desc_proto.options.map_entry,
964
+ # pylint: disable=protected-access
965
+ create_key=descriptor._internal_create_key,
966
+ )
967
+ for nested in desc.nested_types:
968
+ nested.containing_type = desc
969
+ for enum in desc.enum_types:
970
+ enum.containing_type = desc
971
+ for field_index, field_desc in enumerate(desc_proto.field):
972
+ if field_desc.HasField('oneof_index'):
973
+ oneof_index = field_desc.oneof_index
974
+ oneofs[oneof_index].fields.append(fields[field_index])
975
+ fields[field_index].containing_oneof = oneofs[oneof_index]
976
+
977
+ scope[_PrefixWithDot(desc_name)] = desc
978
+ self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
979
+ self._descriptors[desc_name] = desc
980
+ return desc
981
+
982
+ def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
983
+ containing_type=None, scope=None, top_level=False):
984
+ """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
985
+
986
+ Args:
987
+ enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
988
+ package: Optional package name for the new message EnumDescriptor.
989
+ file_desc: The file containing the enum descriptor.
990
+ containing_type: The type containing this enum.
991
+ scope: Scope containing available types.
992
+ top_level: If True, the enum is a top level symbol. If False, the enum
993
+ is defined inside a message.
994
+
995
+ Returns:
996
+ The added descriptor
997
+ """
998
+
999
+ if package:
1000
+ enum_name = '.'.join((package, enum_proto.name))
1001
+ else:
1002
+ enum_name = enum_proto.name
1003
+
1004
+ if file_desc is None:
1005
+ file_name = None
1006
+ else:
1007
+ file_name = file_desc.name
1008
+
1009
+ values = [self._MakeEnumValueDescriptor(value, index)
1010
+ for index, value in enumerate(enum_proto.value)]
1011
+ desc = descriptor.EnumDescriptor(name=enum_proto.name,
1012
+ full_name=enum_name,
1013
+ filename=file_name,
1014
+ file=file_desc,
1015
+ values=values,
1016
+ containing_type=containing_type,
1017
+ options=_OptionsOrNone(enum_proto),
1018
+ # pylint: disable=protected-access
1019
+ create_key=descriptor._internal_create_key)
1020
+ scope['.%s' % enum_name] = desc
1021
+ self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
1022
+ self._enum_descriptors[enum_name] = desc
1023
+
1024
+ # Add top level enum values.
1025
+ if top_level:
1026
+ for value in values:
1027
+ full_name = _NormalizeFullyQualifiedName(
1028
+ '.'.join((package, value.name)))
1029
+ self._CheckConflictRegister(value, full_name, file_name)
1030
+ self._top_enum_values[full_name] = value
1031
+
1032
+ return desc
1033
+
1034
+ def _MakeFieldDescriptor(self, field_proto, message_name, index,
1035
+ file_desc, is_extension=False):
1036
+ """Creates a field descriptor from a FieldDescriptorProto.
1037
+
1038
+ For message and enum type fields, this method will do a look up
1039
+ in the pool for the appropriate descriptor for that type. If it
1040
+ is unavailable, it will fall back to the _source function to
1041
+ create it. If this type is still unavailable, construction will
1042
+ fail.
1043
+
1044
+ Args:
1045
+ field_proto: The proto describing the field.
1046
+ message_name: The name of the containing message.
1047
+ index: Index of the field
1048
+ file_desc: The file containing the field descriptor.
1049
+ is_extension: Indication that this field is for an extension.
1050
+
1051
+ Returns:
1052
+ An initialized FieldDescriptor object
1053
+ """
1054
+
1055
+ if message_name:
1056
+ full_name = '.'.join((message_name, field_proto.name))
1057
+ else:
1058
+ full_name = field_proto.name
1059
+
1060
+ if field_proto.json_name:
1061
+ json_name = field_proto.json_name
1062
+ else:
1063
+ json_name = None
1064
+
1065
+ return descriptor.FieldDescriptor(
1066
+ name=field_proto.name,
1067
+ full_name=full_name,
1068
+ index=index,
1069
+ number=field_proto.number,
1070
+ type=field_proto.type,
1071
+ cpp_type=None,
1072
+ message_type=None,
1073
+ enum_type=None,
1074
+ containing_type=None,
1075
+ label=field_proto.label,
1076
+ has_default_value=False,
1077
+ default_value=None,
1078
+ is_extension=is_extension,
1079
+ extension_scope=None,
1080
+ options=_OptionsOrNone(field_proto),
1081
+ json_name=json_name,
1082
+ file=file_desc,
1083
+ # pylint: disable=protected-access
1084
+ create_key=descriptor._internal_create_key)
1085
+
1086
+ def _SetAllFieldTypes(self, package, desc_proto, scope):
1087
+ """Sets all the descriptor's fields's types.
1088
+
1089
+ This method also sets the containing types on any extensions.
1090
+
1091
+ Args:
1092
+ package: The current package of desc_proto.
1093
+ desc_proto: The message descriptor to update.
1094
+ scope: Enclosing scope of available types.
1095
+ """
1096
+
1097
+ package = _PrefixWithDot(package)
1098
+
1099
+ main_desc = self._GetTypeFromScope(package, desc_proto.name, scope)
1100
+
1101
+ if package == '.':
1102
+ nested_package = _PrefixWithDot(desc_proto.name)
1103
+ else:
1104
+ nested_package = '.'.join([package, desc_proto.name])
1105
+
1106
+ for field_proto, field_desc in zip(desc_proto.field, main_desc.fields):
1107
+ self._SetFieldType(field_proto, field_desc, nested_package, scope)
1108
+
1109
+ for extension_proto, extension_desc in (
1110
+ zip(desc_proto.extension, main_desc.extensions)):
1111
+ extension_desc.containing_type = self._GetTypeFromScope(
1112
+ nested_package, extension_proto.extendee, scope)
1113
+ self._SetFieldType(extension_proto, extension_desc, nested_package, scope)
1114
+
1115
+ for nested_type in desc_proto.nested_type:
1116
+ self._SetAllFieldTypes(nested_package, nested_type, scope)
1117
+
1118
+ def _SetFieldType(self, field_proto, field_desc, package, scope):
1119
+ """Sets the field's type, cpp_type, message_type and enum_type.
1120
+
1121
+ Args:
1122
+ field_proto: Data about the field in proto format.
1123
+ field_desc: The descriptor to modify.
1124
+ package: The package the field's container is in.
1125
+ scope: Enclosing scope of available types.
1126
+ """
1127
+ if field_proto.type_name:
1128
+ desc = self._GetTypeFromScope(package, field_proto.type_name, scope)
1129
+ else:
1130
+ desc = None
1131
+
1132
+ if not field_proto.HasField('type'):
1133
+ if isinstance(desc, descriptor.Descriptor):
1134
+ field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE
1135
+ else:
1136
+ field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM
1137
+
1138
+ field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType(
1139
+ field_proto.type)
1140
+
1141
+ if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE
1142
+ or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP):
1143
+ field_desc.message_type = desc
1144
+
1145
+ if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1146
+ field_desc.enum_type = desc
1147
+
1148
+ if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED:
1149
+ field_desc.has_default_value = False
1150
+ field_desc.default_value = []
1151
+ elif field_proto.HasField('default_value'):
1152
+ field_desc.has_default_value = True
1153
+ if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1154
+ field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1155
+ field_desc.default_value = float(field_proto.default_value)
1156
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1157
+ field_desc.default_value = field_proto.default_value
1158
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1159
+ field_desc.default_value = field_proto.default_value.lower() == 'true'
1160
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1161
+ field_desc.default_value = field_desc.enum_type.values_by_name[
1162
+ field_proto.default_value].number
1163
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1164
+ field_desc.default_value = text_encoding.CUnescape(
1165
+ field_proto.default_value)
1166
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1167
+ field_desc.default_value = None
1168
+ else:
1169
+ # All other types are of the "int" type.
1170
+ field_desc.default_value = int(field_proto.default_value)
1171
+ else:
1172
+ field_desc.has_default_value = False
1173
+ if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1174
+ field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1175
+ field_desc.default_value = 0.0
1176
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1177
+ field_desc.default_value = u''
1178
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1179
+ field_desc.default_value = False
1180
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1181
+ field_desc.default_value = field_desc.enum_type.values[0].number
1182
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1183
+ field_desc.default_value = b''
1184
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1185
+ field_desc.default_value = None
1186
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP:
1187
+ field_desc.default_value = None
1188
+ else:
1189
+ # All other types are of the "int" type.
1190
+ field_desc.default_value = 0
1191
+
1192
+ field_desc.type = field_proto.type
1193
+
1194
+ def _MakeEnumValueDescriptor(self, value_proto, index):
1195
+ """Creates a enum value descriptor object from a enum value proto.
1196
+
1197
+ Args:
1198
+ value_proto: The proto describing the enum value.
1199
+ index: The index of the enum value.
1200
+
1201
+ Returns:
1202
+ An initialized EnumValueDescriptor object.
1203
+ """
1204
+
1205
+ return descriptor.EnumValueDescriptor(
1206
+ name=value_proto.name,
1207
+ index=index,
1208
+ number=value_proto.number,
1209
+ options=_OptionsOrNone(value_proto),
1210
+ type=None,
1211
+ # pylint: disable=protected-access
1212
+ create_key=descriptor._internal_create_key)
1213
+
1214
+ def _MakeServiceDescriptor(self, service_proto, service_index, scope,
1215
+ package, file_desc):
1216
+ """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.
1217
+
1218
+ Args:
1219
+ service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
1220
+ service_index: The index of the service in the File.
1221
+ scope: Dict mapping short and full symbols to message and enum types.
1222
+ package: Optional package name for the new message EnumDescriptor.
1223
+ file_desc: The file containing the service descriptor.
1224
+
1225
+ Returns:
1226
+ The added descriptor.
1227
+ """
1228
+
1229
+ if package:
1230
+ service_name = '.'.join((package, service_proto.name))
1231
+ else:
1232
+ service_name = service_proto.name
1233
+
1234
+ methods = [self._MakeMethodDescriptor(method_proto, service_name, package,
1235
+ scope, index)
1236
+ for index, method_proto in enumerate(service_proto.method)]
1237
+ desc = descriptor.ServiceDescriptor(
1238
+ name=service_proto.name,
1239
+ full_name=service_name,
1240
+ index=service_index,
1241
+ methods=methods,
1242
+ options=_OptionsOrNone(service_proto),
1243
+ file=file_desc,
1244
+ # pylint: disable=protected-access
1245
+ create_key=descriptor._internal_create_key)
1246
+ self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
1247
+ self._service_descriptors[service_name] = desc
1248
+ return desc
1249
+
1250
+ def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
1251
+ index):
1252
+ """Creates a method descriptor from a MethodDescriptorProto.
1253
+
1254
+ Args:
1255
+ method_proto: The proto describing the method.
1256
+ service_name: The name of the containing service.
1257
+ package: Optional package name to look up for types.
1258
+ scope: Scope containing available types.
1259
+ index: Index of the method in the service.
1260
+
1261
+ Returns:
1262
+ An initialized MethodDescriptor object.
1263
+ """
1264
+ full_name = '.'.join((service_name, method_proto.name))
1265
+ input_type = self._GetTypeFromScope(
1266
+ package, method_proto.input_type, scope)
1267
+ output_type = self._GetTypeFromScope(
1268
+ package, method_proto.output_type, scope)
1269
+ return descriptor.MethodDescriptor(
1270
+ name=method_proto.name,
1271
+ full_name=full_name,
1272
+ index=index,
1273
+ containing_service=None,
1274
+ input_type=input_type,
1275
+ output_type=output_type,
1276
+ client_streaming=method_proto.client_streaming,
1277
+ server_streaming=method_proto.server_streaming,
1278
+ options=_OptionsOrNone(method_proto),
1279
+ # pylint: disable=protected-access
1280
+ create_key=descriptor._internal_create_key)
1281
+
1282
+ def _ExtractSymbols(self, descriptors):
1283
+ """Pulls out all the symbols from descriptor protos.
1284
+
1285
+ Args:
1286
+ descriptors: The messages to extract descriptors from.
1287
+ Yields:
1288
+ A two element tuple of the type name and descriptor object.
1289
+ """
1290
+
1291
+ for desc in descriptors:
1292
+ yield (_PrefixWithDot(desc.full_name), desc)
1293
+ for symbol in self._ExtractSymbols(desc.nested_types):
1294
+ yield symbol
1295
+ for enum in desc.enum_types:
1296
+ yield (_PrefixWithDot(enum.full_name), enum)
1297
+
1298
+ def _GetDeps(self, dependencies, visited=None):
1299
+ """Recursively finds dependencies for file protos.
1300
+
1301
+ Args:
1302
+ dependencies: The names of the files being depended on.
1303
+ visited: The names of files already found.
1304
+
1305
+ Yields:
1306
+ Each direct and indirect dependency.
1307
+ """
1308
+
1309
+ visited = visited or set()
1310
+ for dependency in dependencies:
1311
+ if dependency not in visited:
1312
+ visited.add(dependency)
1313
+ dep_desc = self.FindFileByName(dependency)
1314
+ yield dep_desc
1315
+ public_files = [d.name for d in dep_desc.public_dependencies]
1316
+ yield from self._GetDeps(public_files, visited)
1317
+
1318
+ def _GetTypeFromScope(self, package, type_name, scope):
1319
+ """Finds a given type name in the current scope.
1320
+
1321
+ Args:
1322
+ package: The package the proto should be located in.
1323
+ type_name: The name of the type to be found in the scope.
1324
+ scope: Dict mapping short and full symbols to message and enum types.
1325
+
1326
+ Returns:
1327
+ The descriptor for the requested type.
1328
+ """
1329
+ if type_name not in scope:
1330
+ components = _PrefixWithDot(package).split('.')
1331
+ while components:
1332
+ possible_match = '.'.join(components + [type_name])
1333
+ if possible_match in scope:
1334
+ type_name = possible_match
1335
+ break
1336
+ else:
1337
+ components.pop(-1)
1338
+ return scope[type_name]
1339
+
1340
+
1341
+ def _PrefixWithDot(name):
1342
+ return name if name.startswith('.') else '.%s' % name
1343
+
1344
+
1345
+ if _USE_C_DESCRIPTORS:
1346
+ # TODO: This pool could be constructed from Python code, when we
1347
+ # support a flag like 'use_cpp_generated_pool=True'.
1348
+ # pylint: disable=protected-access
1349
+ _DEFAULT = descriptor._message.default_pool
1350
+ else:
1351
+ _DEFAULT = DescriptorPool()
1352
+
1353
+
1354
+ def Default():
1355
+ return _DEFAULT
.venv/lib/python3.11/site-packages/google/protobuf/duration.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protocol Buffers - Google's data interchange format
2
+ # Copyright 2008 Google Inc. All rights reserved.
3
+ #
4
+ # Use of this source code is governed by a BSD-style
5
+ # license that can be found in the LICENSE file or at
6
+ # https://developers.google.com/open-source/licenses/bsd
7
+
8
+ """Contains the Duration helper APIs."""
9
+
10
+ import datetime
11
+
12
+ from google.protobuf.duration_pb2 import Duration
13
+
14
+
15
+ def from_json_string(value: str) -> Duration:
16
+ """Converts a string to Duration.
17
+
18
+ Args:
19
+ value: A string to be converted. The string must end with 's'. Any
20
+ fractional digits (or none) are accepted as long as they fit into
21
+ precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s"
22
+
23
+ Raises:
24
+ ValueError: On parsing problems.
25
+ """
26
+ duration = Duration()
27
+ duration.FromJsonString(value)
28
+ return duration
29
+
30
+
31
+ def from_microseconds(micros: float) -> Duration:
32
+ """Converts microseconds to Duration."""
33
+ duration = Duration()
34
+ duration.FromMicroseconds(micros)
35
+ return duration
36
+
37
+
38
+ def from_milliseconds(millis: float) -> Duration:
39
+ """Converts milliseconds to Duration."""
40
+ duration = Duration()
41
+ duration.FromMilliseconds(millis)
42
+ return duration
43
+
44
+
45
+ def from_nanoseconds(nanos: float) -> Duration:
46
+ """Converts nanoseconds to Duration."""
47
+ duration = Duration()
48
+ duration.FromNanoseconds(nanos)
49
+ return duration
50
+
51
+
52
+ def from_seconds(seconds: float) -> Duration:
53
+ """Converts seconds to Duration."""
54
+ duration = Duration()
55
+ duration.FromSeconds(seconds)
56
+ return duration
57
+
58
+
59
+ def from_timedelta(td: datetime.timedelta) -> Duration:
60
+ """Converts timedelta to Duration."""
61
+ duration = Duration()
62
+ duration.FromTimedelta(td)
63
+ return duration
64
+
65
+
66
+ def to_json_string(duration: Duration) -> str:
67
+ """Converts Duration to string format.
68
+
69
+ Returns:
70
+ A string converted from self. The string format will contains
71
+ 3, 6, or 9 fractional digits depending on the precision required to
72
+ represent the exact Duration value. For example: "1s", "1.010s",
73
+ "1.000000100s", "-3.100s"
74
+ """
75
+ return duration.ToJsonString()
76
+
77
+
78
+ def to_microseconds(duration: Duration) -> int:
79
+ """Converts a Duration to microseconds."""
80
+ return duration.ToMicroseconds()
81
+
82
+
83
+ def to_milliseconds(duration: Duration) -> int:
84
+ """Converts a Duration to milliseconds."""
85
+ return duration.ToMilliseconds()
86
+
87
+
88
+ def to_nanoseconds(duration: Duration) -> int:
89
+ """Converts a Duration to nanoseconds."""
90
+ return duration.ToNanoseconds()
91
+
92
+
93
+ def to_seconds(duration: Duration) -> int:
94
+ """Converts a Duration to seconds."""
95
+ return duration.ToSeconds()
96
+
97
+
98
+ def to_timedelta(duration: Duration) -> datetime.timedelta:
99
+ """Converts Duration to timedelta."""
100
+ return duration.ToTimedelta()
.venv/lib/python3.11/site-packages/google/protobuf/duration_pb2.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
4
+ # source: google/protobuf/duration.proto
5
+ # Protobuf Python Version: 5.29.3
6
+ """Generated protocol buffer code."""
7
+ from google.protobuf import descriptor as _descriptor
8
+ from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
10
+ from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 5,
15
+ 29,
16
+ 3,
17
+ '',
18
+ 'google/protobuf/duration.proto'
19
+ )
20
+ # @@protoc_insertion_point(imports)
21
+
22
+ _sym_db = _symbol_database.Default()
23
+
24
+
25
+
26
+
27
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1egoogle/protobuf/duration.proto\x12\x0fgoogle.protobuf\":\n\x08\x44uration\x12\x18\n\x07seconds\x18\x01 \x01(\x03R\x07seconds\x12\x14\n\x05nanos\x18\x02 \x01(\x05R\x05nanosB\x83\x01\n\x13\x63om.google.protobufB\rDurationProtoP\x01Z1google.golang.org/protobuf/types/known/durationpb\xf8\x01\x01\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')
28
+
29
+ _globals = globals()
30
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
31
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.duration_pb2', _globals)
32
+ if not _descriptor._USE_C_DESCRIPTORS:
33
+ _globals['DESCRIPTOR']._loaded_options = None
34
+ _globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\rDurationProtoP\001Z1google.golang.org/protobuf/types/known/durationpb\370\001\001\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
35
+ _globals['_DURATION']._serialized_start=51
36
+ _globals['_DURATION']._serialized_end=109
37
+ # @@protoc_insertion_point(module_scope)
.venv/lib/python3.11/site-packages/google/protobuf/empty_pb2.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
4
+ # source: google/protobuf/empty.proto
5
+ # Protobuf Python Version: 5.29.3
6
+ """Generated protocol buffer code."""
7
+ from google.protobuf import descriptor as _descriptor
8
+ from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
10
+ from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 5,
15
+ 29,
16
+ 3,
17
+ '',
18
+ 'google/protobuf/empty.proto'
19
+ )
20
+ # @@protoc_insertion_point(imports)
21
+
22
+ _sym_db = _symbol_database.Default()
23
+
24
+
25
+
26
+
27
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1bgoogle/protobuf/empty.proto\x12\x0fgoogle.protobuf\"\x07\n\x05\x45mptyB}\n\x13\x63om.google.protobufB\nEmptyProtoP\x01Z.google.golang.org/protobuf/types/known/emptypb\xf8\x01\x01\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')
28
+
29
+ _globals = globals()
30
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
31
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.empty_pb2', _globals)
32
+ if not _descriptor._USE_C_DESCRIPTORS:
33
+ _globals['DESCRIPTOR']._loaded_options = None
34
+ _globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\nEmptyProtoP\001Z.google.golang.org/protobuf/types/known/emptypb\370\001\001\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
35
+ _globals['_EMPTY']._serialized_start=48
36
+ _globals['_EMPTY']._serialized_end=55
37
+ # @@protoc_insertion_point(module_scope)
.venv/lib/python3.11/site-packages/google/protobuf/field_mask_pb2.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
4
+ # source: google/protobuf/field_mask.proto
5
+ # Protobuf Python Version: 5.29.3
6
+ """Generated protocol buffer code."""
7
+ from google.protobuf import descriptor as _descriptor
8
+ from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
10
+ from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 5,
15
+ 29,
16
+ 3,
17
+ '',
18
+ 'google/protobuf/field_mask.proto'
19
+ )
20
+ # @@protoc_insertion_point(imports)
21
+
22
+ _sym_db = _symbol_database.Default()
23
+
24
+
25
+
26
+
27
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n google/protobuf/field_mask.proto\x12\x0fgoogle.protobuf\"!\n\tFieldMask\x12\x14\n\x05paths\x18\x01 \x03(\tR\x05pathsB\x85\x01\n\x13\x63om.google.protobufB\x0e\x46ieldMaskProtoP\x01Z2google.golang.org/protobuf/types/known/fieldmaskpb\xf8\x01\x01\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')
28
+
29
+ _globals = globals()
30
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
31
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.field_mask_pb2', _globals)
32
+ if not _descriptor._USE_C_DESCRIPTORS:
33
+ _globals['DESCRIPTOR']._loaded_options = None
34
+ _globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\016FieldMaskProtoP\001Z2google.golang.org/protobuf/types/known/fieldmaskpb\370\001\001\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
35
+ _globals['_FIELDMASK']._serialized_start=53
36
+ _globals['_FIELDMASK']._serialized_end=86
37
+ # @@protoc_insertion_point(module_scope)
.venv/lib/python3.11/site-packages/google/protobuf/internal/_parameterized.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python
2
+ #
3
+ # Protocol Buffers - Google's data interchange format
4
+ # Copyright 2008 Google Inc. All rights reserved.
5
+ #
6
+ # Use of this source code is governed by a BSD-style
7
+ # license that can be found in the LICENSE file or at
8
+ # https://developers.google.com/open-source/licenses/bsd
9
+
10
+ """Adds support for parameterized tests to Python's unittest TestCase class.
11
+
12
+ A parameterized test is a method in a test case that is invoked with different
13
+ argument tuples.
14
+
15
+ A simple example:
16
+
17
+ class AdditionExample(_parameterized.TestCase):
18
+ @_parameterized.parameters(
19
+ (1, 2, 3),
20
+ (4, 5, 9),
21
+ (1, 1, 3))
22
+ def testAddition(self, op1, op2, result):
23
+ self.assertEqual(result, op1 + op2)
24
+
25
+
26
+ Each invocation is a separate test case and properly isolated just
27
+ like a normal test method, with its own setUp/tearDown cycle. In the
28
+ example above, there are three separate testcases, one of which will
29
+ fail due to an assertion error (1 + 1 != 3).
30
+
31
+ Parameters for individual test cases can be tuples (with positional parameters)
32
+ or dictionaries (with named parameters):
33
+
34
+ class AdditionExample(_parameterized.TestCase):
35
+ @_parameterized.parameters(
36
+ {'op1': 1, 'op2': 2, 'result': 3},
37
+ {'op1': 4, 'op2': 5, 'result': 9},
38
+ )
39
+ def testAddition(self, op1, op2, result):
40
+ self.assertEqual(result, op1 + op2)
41
+
42
+ If a parameterized test fails, the error message will show the
43
+ original test name (which is modified internally) and the arguments
44
+ for the specific invocation, which are part of the string returned by
45
+ the shortDescription() method on test cases.
46
+
47
+ The id method of the test, used internally by the unittest framework,
48
+ is also modified to show the arguments. To make sure that test names
49
+ stay the same across several invocations, object representations like
50
+
51
+ >>> class Foo(object):
52
+ ... pass
53
+ >>> repr(Foo())
54
+ '<__main__.Foo object at 0x23d8610>'
55
+
56
+ are turned into '<__main__.Foo>'. For even more descriptive names,
57
+ especially in test logs, you can use the named_parameters decorator. In
58
+ this case, only tuples are supported, and the first parameters has to
59
+ be a string (or an object that returns an apt name when converted via
60
+ str()):
61
+
62
+ class NamedExample(_parameterized.TestCase):
63
+ @_parameterized.named_parameters(
64
+ ('Normal', 'aa', 'aaa', True),
65
+ ('EmptyPrefix', '', 'abc', True),
66
+ ('BothEmpty', '', '', True))
67
+ def testStartsWith(self, prefix, string, result):
68
+ self.assertEqual(result, strings.startswith(prefix))
69
+
70
+ Named tests also have the benefit that they can be run individually
71
+ from the command line:
72
+
73
+ $ testmodule.py NamedExample.testStartsWithNormal
74
+ .
75
+ --------------------------------------------------------------------
76
+ Ran 1 test in 0.000s
77
+
78
+ OK
79
+
80
+ Parameterized Classes
81
+ =====================
82
+ If invocation arguments are shared across test methods in a single
83
+ TestCase class, instead of decorating all test methods
84
+ individually, the class itself can be decorated:
85
+
86
+ @_parameterized.parameters(
87
+ (1, 2, 3)
88
+ (4, 5, 9))
89
+ class ArithmeticTest(_parameterized.TestCase):
90
+ def testAdd(self, arg1, arg2, result):
91
+ self.assertEqual(arg1 + arg2, result)
92
+
93
+ def testSubtract(self, arg2, arg2, result):
94
+ self.assertEqual(result - arg1, arg2)
95
+
96
+ Inputs from Iterables
97
+ =====================
98
+ If parameters should be shared across several test cases, or are dynamically
99
+ created from other sources, a single non-tuple iterable can be passed into
100
+ the decorator. This iterable will be used to obtain the test cases:
101
+
102
+ class AdditionExample(_parameterized.TestCase):
103
+ @_parameterized.parameters(
104
+ c.op1, c.op2, c.result for c in testcases
105
+ )
106
+ def testAddition(self, op1, op2, result):
107
+ self.assertEqual(result, op1 + op2)
108
+
109
+
110
+ Single-Argument Test Methods
111
+ ============================
112
+ If a test method takes only one argument, the single argument does not need to
113
+ be wrapped into a tuple:
114
+
115
+ class NegativeNumberExample(_parameterized.TestCase):
116
+ @_parameterized.parameters(
117
+ -1, -3, -4, -5
118
+ )
119
+ def testIsNegative(self, arg):
120
+ self.assertTrue(IsNegative(arg))
121
+ """
122
+
123
+ __author__ = 'tmarek@google.com (Torsten Marek)'
124
+
125
+ import functools
126
+ import re
127
+ import types
128
+ import unittest
129
+ import uuid
130
+
131
+ try:
132
+ # Since python 3
133
+ import collections.abc as collections_abc
134
+ except ImportError:
135
+ # Won't work after python 3.8
136
+ import collections as collections_abc
137
+
138
+ ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>')
139
+ _SEPARATOR = uuid.uuid1().hex
140
+ _FIRST_ARG = object()
141
+ _ARGUMENT_REPR = object()
142
+
143
+
144
+ def _CleanRepr(obj):
145
+ return ADDR_RE.sub(r'<\1>', repr(obj))
146
+
147
+
148
+ # Helper function formerly from the unittest module, removed from it in
149
+ # Python 2.7.
150
+ def _StrClass(cls):
151
+ return '%s.%s' % (cls.__module__, cls.__name__)
152
+
153
+
154
+ def _NonStringIterable(obj):
155
+ return (isinstance(obj, collections_abc.Iterable) and
156
+ not isinstance(obj, str))
157
+
158
+
159
+ def _FormatParameterList(testcase_params):
160
+ if isinstance(testcase_params, collections_abc.Mapping):
161
+ return ', '.join('%s=%s' % (argname, _CleanRepr(value))
162
+ for argname, value in testcase_params.items())
163
+ elif _NonStringIterable(testcase_params):
164
+ return ', '.join(map(_CleanRepr, testcase_params))
165
+ else:
166
+ return _FormatParameterList((testcase_params,))
167
+
168
+
169
+ class _ParameterizedTestIter(object):
170
+ """Callable and iterable class for producing new test cases."""
171
+
172
+ def __init__(self, test_method, testcases, naming_type):
173
+ """Returns concrete test functions for a test and a list of parameters.
174
+
175
+ The naming_type is used to determine the name of the concrete
176
+ functions as reported by the unittest framework. If naming_type is
177
+ _FIRST_ARG, the testcases must be tuples, and the first element must
178
+ have a string representation that is a valid Python identifier.
179
+
180
+ Args:
181
+ test_method: The decorated test method.
182
+ testcases: (list of tuple/dict) A list of parameter
183
+ tuples/dicts for individual test invocations.
184
+ naming_type: The test naming type, either _NAMED or _ARGUMENT_REPR.
185
+ """
186
+ self._test_method = test_method
187
+ self.testcases = testcases
188
+ self._naming_type = naming_type
189
+
190
+ def __call__(self, *args, **kwargs):
191
+ raise RuntimeError('You appear to be running a parameterized test case '
192
+ 'without having inherited from parameterized.'
193
+ 'TestCase. This is bad because none of '
194
+ 'your test cases are actually being run.')
195
+
196
+ def __iter__(self):
197
+ test_method = self._test_method
198
+ naming_type = self._naming_type
199
+
200
+ def MakeBoundParamTest(testcase_params):
201
+ @functools.wraps(test_method)
202
+ def BoundParamTest(self):
203
+ if isinstance(testcase_params, collections_abc.Mapping):
204
+ test_method(self, **testcase_params)
205
+ elif _NonStringIterable(testcase_params):
206
+ test_method(self, *testcase_params)
207
+ else:
208
+ test_method(self, testcase_params)
209
+
210
+ if naming_type is _FIRST_ARG:
211
+ # Signal the metaclass that the name of the test function is unique
212
+ # and descriptive.
213
+ BoundParamTest.__x_use_name__ = True
214
+ BoundParamTest.__name__ += str(testcase_params[0])
215
+ testcase_params = testcase_params[1:]
216
+ elif naming_type is _ARGUMENT_REPR:
217
+ # __x_extra_id__ is used to pass naming information to the __new__
218
+ # method of TestGeneratorMetaclass.
219
+ # The metaclass will make sure to create a unique, but nondescriptive
220
+ # name for this test.
221
+ BoundParamTest.__x_extra_id__ = '(%s)' % (
222
+ _FormatParameterList(testcase_params),)
223
+ else:
224
+ raise RuntimeError('%s is not a valid naming type.' % (naming_type,))
225
+
226
+ BoundParamTest.__doc__ = '%s(%s)' % (
227
+ BoundParamTest.__name__, _FormatParameterList(testcase_params))
228
+ if test_method.__doc__:
229
+ BoundParamTest.__doc__ += '\n%s' % (test_method.__doc__,)
230
+ return BoundParamTest
231
+ return (MakeBoundParamTest(c) for c in self.testcases)
232
+
233
+
234
+ def _IsSingletonList(testcases):
235
+ """True iff testcases contains only a single non-tuple element."""
236
+ return len(testcases) == 1 and not isinstance(testcases[0], tuple)
237
+
238
+
239
+ def _ModifyClass(class_object, testcases, naming_type):
240
+ assert not getattr(class_object, '_id_suffix', None), (
241
+ 'Cannot add parameters to %s,'
242
+ ' which already has parameterized methods.' % (class_object,))
243
+ class_object._id_suffix = id_suffix = {}
244
+ # We change the size of __dict__ while we iterate over it,
245
+ # which Python 3.x will complain about, so use copy().
246
+ for name, obj in class_object.__dict__.copy().items():
247
+ if (name.startswith(unittest.TestLoader.testMethodPrefix)
248
+ and isinstance(obj, types.FunctionType)):
249
+ delattr(class_object, name)
250
+ methods = {}
251
+ _UpdateClassDictForParamTestCase(
252
+ methods, id_suffix, name,
253
+ _ParameterizedTestIter(obj, testcases, naming_type))
254
+ for name, meth in methods.items():
255
+ setattr(class_object, name, meth)
256
+
257
+
258
+ def _ParameterDecorator(naming_type, testcases):
259
+ """Implementation of the parameterization decorators.
260
+
261
+ Args:
262
+ naming_type: The naming type.
263
+ testcases: Testcase parameters.
264
+
265
+ Returns:
266
+ A function for modifying the decorated object.
267
+ """
268
+ def _Apply(obj):
269
+ if isinstance(obj, type):
270
+ _ModifyClass(
271
+ obj,
272
+ list(testcases) if not isinstance(testcases, collections_abc.Sequence)
273
+ else testcases,
274
+ naming_type)
275
+ return obj
276
+ else:
277
+ return _ParameterizedTestIter(obj, testcases, naming_type)
278
+
279
+ if _IsSingletonList(testcases):
280
+ assert _NonStringIterable(testcases[0]), (
281
+ 'Single parameter argument must be a non-string iterable')
282
+ testcases = testcases[0]
283
+
284
+ return _Apply
285
+
286
+
287
+ def parameters(*testcases): # pylint: disable=invalid-name
288
+ """A decorator for creating parameterized tests.
289
+
290
+ See the module docstring for a usage example.
291
+ Args:
292
+ *testcases: Parameters for the decorated method, either a single
293
+ iterable, or a list of tuples/dicts/objects (for tests
294
+ with only one argument).
295
+
296
+ Returns:
297
+ A test generator to be handled by TestGeneratorMetaclass.
298
+ """
299
+ return _ParameterDecorator(_ARGUMENT_REPR, testcases)
300
+
301
+
302
+ def named_parameters(*testcases): # pylint: disable=invalid-name
303
+ """A decorator for creating parameterized tests.
304
+
305
+ See the module docstring for a usage example. The first element of
306
+ each parameter tuple should be a string and will be appended to the
307
+ name of the test method.
308
+
309
+ Args:
310
+ *testcases: Parameters for the decorated method, either a single
311
+ iterable, or a list of tuples.
312
+
313
+ Returns:
314
+ A test generator to be handled by TestGeneratorMetaclass.
315
+ """
316
+ return _ParameterDecorator(_FIRST_ARG, testcases)
317
+
318
+
319
+ class TestGeneratorMetaclass(type):
320
+ """Metaclass for test cases with test generators.
321
+
322
+ A test generator is an iterable in a testcase that produces callables. These
323
+ callables must be single-argument methods. These methods are injected into
324
+ the class namespace and the original iterable is removed. If the name of the
325
+ iterable conforms to the test pattern, the injected methods will be picked
326
+ up as tests by the unittest framework.
327
+
328
+ In general, it is supposed to be used in conjunction with the
329
+ parameters decorator.
330
+ """
331
+
332
+ def __new__(mcs, class_name, bases, dct):
333
+ dct['_id_suffix'] = id_suffix = {}
334
+ for name, obj in dct.copy().items():
335
+ if (name.startswith(unittest.TestLoader.testMethodPrefix) and
336
+ _NonStringIterable(obj)):
337
+ iterator = iter(obj)
338
+ dct.pop(name)
339
+ _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator)
340
+
341
+ return type.__new__(mcs, class_name, bases, dct)
342
+
343
+
344
+ def _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator):
345
+ """Adds individual test cases to a dictionary.
346
+
347
+ Args:
348
+ dct: The target dictionary.
349
+ id_suffix: The dictionary for mapping names to test IDs.
350
+ name: The original name of the test case.
351
+ iterator: The iterator generating the individual test cases.
352
+ """
353
+ for idx, func in enumerate(iterator):
354
+ assert callable(func), 'Test generators must yield callables, got %r' % (
355
+ func,)
356
+ if getattr(func, '__x_use_name__', False):
357
+ new_name = func.__name__
358
+ else:
359
+ new_name = '%s%s%d' % (name, _SEPARATOR, idx)
360
+ assert new_name not in dct, (
361
+ 'Name of parameterized test case "%s" not unique' % (new_name,))
362
+ dct[new_name] = func
363
+ id_suffix[new_name] = getattr(func, '__x_extra_id__', '')
364
+
365
+
366
+ class TestCase(unittest.TestCase, metaclass=TestGeneratorMetaclass):
367
+ """Base class for test cases using the parameters decorator."""
368
+
369
+ def _OriginalName(self):
370
+ return self._testMethodName.split(_SEPARATOR)[0]
371
+
372
+ def __str__(self):
373
+ return '%s (%s)' % (self._OriginalName(), _StrClass(self.__class__))
374
+
375
+ def id(self): # pylint: disable=invalid-name
376
+ """Returns the descriptive ID of the test.
377
+
378
+ This is used internally by the unittesting framework to get a name
379
+ for the test to be used in reports.
380
+
381
+ Returns:
382
+ The test id.
383
+ """
384
+ return '%s.%s%s' % (_StrClass(self.__class__),
385
+ self._OriginalName(),
386
+ self._id_suffix.get(self._testMethodName, ''))
387
+
388
+
389
+ def CoopTestCase(other_base_class):
390
+ """Returns a new base class with a cooperative metaclass base.
391
+
392
+ This enables the TestCase to be used in combination
393
+ with other base classes that have custom metaclasses, such as
394
+ mox.MoxTestBase.
395
+
396
+ Only works with metaclasses that do not override type.__new__.
397
+
398
+ Example:
399
+
400
+ import google3
401
+ import mox
402
+
403
+ from google.protobuf.internal import _parameterized
404
+
405
+ class ExampleTest(parameterized.CoopTestCase(mox.MoxTestBase)):
406
+ ...
407
+
408
+ Args:
409
+ other_base_class: (class) A test case base class.
410
+
411
+ Returns:
412
+ A new class object.
413
+ """
414
+ metaclass = type(
415
+ 'CoopMetaclass',
416
+ (other_base_class.__metaclass__,
417
+ TestGeneratorMetaclass), {})
418
+ return metaclass(
419
+ 'CoopTestCase',
420
+ (other_base_class, TestCase), {})
.venv/lib/python3.11/site-packages/google/protobuf/internal/containers.py ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protocol Buffers - Google's data interchange format
2
+ # Copyright 2008 Google Inc. All rights reserved.
3
+ #
4
+ # Use of this source code is governed by a BSD-style
5
+ # license that can be found in the LICENSE file or at
6
+ # https://developers.google.com/open-source/licenses/bsd
7
+
8
+ """Contains container classes to represent different protocol buffer types.
9
+
10
+ This file defines container classes which represent categories of protocol
11
+ buffer field types which need extra maintenance. Currently these categories
12
+ are:
13
+
14
+ - Repeated scalar fields - These are all repeated fields which aren't
15
+ composite (e.g. they are of simple types like int32, string, etc).
16
+ - Repeated composite fields - Repeated fields which are composite. This
17
+ includes groups and nested messages.
18
+ """
19
+
20
+ import collections.abc
21
+ import copy
22
+ import pickle
23
+ from typing import (
24
+ Any,
25
+ Iterable,
26
+ Iterator,
27
+ List,
28
+ MutableMapping,
29
+ MutableSequence,
30
+ NoReturn,
31
+ Optional,
32
+ Sequence,
33
+ TypeVar,
34
+ Union,
35
+ overload,
36
+ )
37
+
38
+
39
+ _T = TypeVar('_T')
40
+ _K = TypeVar('_K')
41
+ _V = TypeVar('_V')
42
+
43
+
44
+ class BaseContainer(Sequence[_T]):
45
+ """Base container class."""
46
+
47
+ # Minimizes memory usage and disallows assignment to other attributes.
48
+ __slots__ = ['_message_listener', '_values']
49
+
50
+ def __init__(self, message_listener: Any) -> None:
51
+ """
52
+ Args:
53
+ message_listener: A MessageListener implementation.
54
+ The RepeatedScalarFieldContainer will call this object's
55
+ Modified() method when it is modified.
56
+ """
57
+ self._message_listener = message_listener
58
+ self._values = []
59
+
60
+ @overload
61
+ def __getitem__(self, key: int) -> _T:
62
+ ...
63
+
64
+ @overload
65
+ def __getitem__(self, key: slice) -> List[_T]:
66
+ ...
67
+
68
+ def __getitem__(self, key):
69
+ """Retrieves item by the specified key."""
70
+ return self._values[key]
71
+
72
+ def __len__(self) -> int:
73
+ """Returns the number of elements in the container."""
74
+ return len(self._values)
75
+
76
+ def __ne__(self, other: Any) -> bool:
77
+ """Checks if another instance isn't equal to this one."""
78
+ # The concrete classes should define __eq__.
79
+ return not self == other
80
+
81
+ __hash__ = None
82
+
83
+ def __repr__(self) -> str:
84
+ return repr(self._values)
85
+
86
+ def sort(self, *args, **kwargs) -> None:
87
+ # Continue to support the old sort_function keyword argument.
88
+ # This is expected to be a rare occurrence, so use LBYL to avoid
89
+ # the overhead of actually catching KeyError.
90
+ if 'sort_function' in kwargs:
91
+ kwargs['cmp'] = kwargs.pop('sort_function')
92
+ self._values.sort(*args, **kwargs)
93
+
94
+ def reverse(self) -> None:
95
+ self._values.reverse()
96
+
97
+
98
+ # TODO: Remove this. BaseContainer does *not* conform to
99
+ # MutableSequence, only its subclasses do.
100
+ collections.abc.MutableSequence.register(BaseContainer)
101
+
102
+
103
+ class RepeatedScalarFieldContainer(BaseContainer[_T], MutableSequence[_T]):
104
+ """Simple, type-checked, list-like container for holding repeated scalars."""
105
+
106
+ # Disallows assignment to other attributes.
107
+ __slots__ = ['_type_checker']
108
+
109
+ def __init__(
110
+ self,
111
+ message_listener: Any,
112
+ type_checker: Any,
113
+ ) -> None:
114
+ """Args:
115
+
116
+ message_listener: A MessageListener implementation. The
117
+ RepeatedScalarFieldContainer will call this object's Modified() method
118
+ when it is modified.
119
+ type_checker: A type_checkers.ValueChecker instance to run on elements
120
+ inserted into this container.
121
+ """
122
+ super().__init__(message_listener)
123
+ self._type_checker = type_checker
124
+
125
+ def append(self, value: _T) -> None:
126
+ """Appends an item to the list. Similar to list.append()."""
127
+ self._values.append(self._type_checker.CheckValue(value))
128
+ if not self._message_listener.dirty:
129
+ self._message_listener.Modified()
130
+
131
+ def insert(self, key: int, value: _T) -> None:
132
+ """Inserts the item at the specified position. Similar to list.insert()."""
133
+ self._values.insert(key, self._type_checker.CheckValue(value))
134
+ if not self._message_listener.dirty:
135
+ self._message_listener.Modified()
136
+
137
+ def extend(self, elem_seq: Iterable[_T]) -> None:
138
+ """Extends by appending the given iterable. Similar to list.extend()."""
139
+ elem_seq_iter = iter(elem_seq)
140
+ new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
141
+ if new_values:
142
+ self._values.extend(new_values)
143
+ self._message_listener.Modified()
144
+
145
+ def MergeFrom(
146
+ self,
147
+ other: Union['RepeatedScalarFieldContainer[_T]', Iterable[_T]],
148
+ ) -> None:
149
+ """Appends the contents of another repeated field of the same type to this
150
+ one. We do not check the types of the individual fields.
151
+ """
152
+ self._values.extend(other)
153
+ self._message_listener.Modified()
154
+
155
+ def remove(self, elem: _T):
156
+ """Removes an item from the list. Similar to list.remove()."""
157
+ self._values.remove(elem)
158
+ self._message_listener.Modified()
159
+
160
+ def pop(self, key: Optional[int] = -1) -> _T:
161
+ """Removes and returns an item at a given index. Similar to list.pop()."""
162
+ value = self._values[key]
163
+ self.__delitem__(key)
164
+ return value
165
+
166
+ @overload
167
+ def __setitem__(self, key: int, value: _T) -> None:
168
+ ...
169
+
170
+ @overload
171
+ def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
172
+ ...
173
+
174
+ def __setitem__(self, key, value) -> None:
175
+ """Sets the item on the specified position."""
176
+ if isinstance(key, slice):
177
+ if key.step is not None:
178
+ raise ValueError('Extended slices not supported')
179
+ self._values[key] = map(self._type_checker.CheckValue, value)
180
+ self._message_listener.Modified()
181
+ else:
182
+ self._values[key] = self._type_checker.CheckValue(value)
183
+ self._message_listener.Modified()
184
+
185
+ def __delitem__(self, key: Union[int, slice]) -> None:
186
+ """Deletes the item at the specified position."""
187
+ del self._values[key]
188
+ self._message_listener.Modified()
189
+
190
+ def __eq__(self, other: Any) -> bool:
191
+ """Compares the current instance with another one."""
192
+ if self is other:
193
+ return True
194
+ # Special case for the same type which should be common and fast.
195
+ if isinstance(other, self.__class__):
196
+ return other._values == self._values
197
+ # We are presumably comparing against some other sequence type.
198
+ return other == self._values
199
+
200
+ def __deepcopy__(
201
+ self,
202
+ unused_memo: Any = None,
203
+ ) -> 'RepeatedScalarFieldContainer[_T]':
204
+ clone = RepeatedScalarFieldContainer(
205
+ copy.deepcopy(self._message_listener), self._type_checker)
206
+ clone.MergeFrom(self)
207
+ return clone
208
+
209
+ def __reduce__(self, **kwargs) -> NoReturn:
210
+ raise pickle.PickleError(
211
+ "Can't pickle repeated scalar fields, convert to list first")
212
+
213
+
214
+ # TODO: Constrain T to be a subtype of Message.
215
+ class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
216
+ """Simple, list-like container for holding repeated composite fields."""
217
+
218
+ # Disallows assignment to other attributes.
219
+ __slots__ = ['_message_descriptor']
220
+
221
+ def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
222
+ """
223
+ Note that we pass in a descriptor instead of the generated directly,
224
+ since at the time we construct a _RepeatedCompositeFieldContainer we
225
+ haven't yet necessarily initialized the type that will be contained in the
226
+ container.
227
+
228
+ Args:
229
+ message_listener: A MessageListener implementation.
230
+ The RepeatedCompositeFieldContainer will call this object's
231
+ Modified() method when it is modified.
232
+ message_descriptor: A Descriptor instance describing the protocol type
233
+ that should be present in this container. We'll use the
234
+ _concrete_class field of this descriptor when the client calls add().
235
+ """
236
+ super().__init__(message_listener)
237
+ self._message_descriptor = message_descriptor
238
+
239
+ def add(self, **kwargs: Any) -> _T:
240
+ """Adds a new element at the end of the list and returns it. Keyword
241
+ arguments may be used to initialize the element.
242
+ """
243
+ new_element = self._message_descriptor._concrete_class(**kwargs)
244
+ new_element._SetListener(self._message_listener)
245
+ self._values.append(new_element)
246
+ if not self._message_listener.dirty:
247
+ self._message_listener.Modified()
248
+ return new_element
249
+
250
+ def append(self, value: _T) -> None:
251
+ """Appends one element by copying the message."""
252
+ new_element = self._message_descriptor._concrete_class()
253
+ new_element._SetListener(self._message_listener)
254
+ new_element.CopyFrom(value)
255
+ self._values.append(new_element)
256
+ if not self._message_listener.dirty:
257
+ self._message_listener.Modified()
258
+
259
+ def insert(self, key: int, value: _T) -> None:
260
+ """Inserts the item at the specified position by copying."""
261
+ new_element = self._message_descriptor._concrete_class()
262
+ new_element._SetListener(self._message_listener)
263
+ new_element.CopyFrom(value)
264
+ self._values.insert(key, new_element)
265
+ if not self._message_listener.dirty:
266
+ self._message_listener.Modified()
267
+
268
+ def extend(self, elem_seq: Iterable[_T]) -> None:
269
+ """Extends by appending the given sequence of elements of the same type
270
+
271
+ as this one, copying each individual message.
272
+ """
273
+ message_class = self._message_descriptor._concrete_class
274
+ listener = self._message_listener
275
+ values = self._values
276
+ for message in elem_seq:
277
+ new_element = message_class()
278
+ new_element._SetListener(listener)
279
+ new_element.MergeFrom(message)
280
+ values.append(new_element)
281
+ listener.Modified()
282
+
283
+ def MergeFrom(
284
+ self,
285
+ other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
286
+ ) -> None:
287
+ """Appends the contents of another repeated field of the same type to this
288
+ one, copying each individual message.
289
+ """
290
+ self.extend(other)
291
+
292
+ def remove(self, elem: _T) -> None:
293
+ """Removes an item from the list. Similar to list.remove()."""
294
+ self._values.remove(elem)
295
+ self._message_listener.Modified()
296
+
297
+ def pop(self, key: Optional[int] = -1) -> _T:
298
+ """Removes and returns an item at a given index. Similar to list.pop()."""
299
+ value = self._values[key]
300
+ self.__delitem__(key)
301
+ return value
302
+
303
+ @overload
304
+ def __setitem__(self, key: int, value: _T) -> None:
305
+ ...
306
+
307
+ @overload
308
+ def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
309
+ ...
310
+
311
+ def __setitem__(self, key, value):
312
+ # This method is implemented to make RepeatedCompositeFieldContainer
313
+ # structurally compatible with typing.MutableSequence. It is
314
+ # otherwise unsupported and will always raise an error.
315
+ raise TypeError(
316
+ f'{self.__class__.__name__} object does not support item assignment')
317
+
318
+ def __delitem__(self, key: Union[int, slice]) -> None:
319
+ """Deletes the item at the specified position."""
320
+ del self._values[key]
321
+ self._message_listener.Modified()
322
+
323
+ def __eq__(self, other: Any) -> bool:
324
+ """Compares the current instance with another one."""
325
+ if self is other:
326
+ return True
327
+ if not isinstance(other, self.__class__):
328
+ raise TypeError('Can only compare repeated composite fields against '
329
+ 'other repeated composite fields.')
330
+ return self._values == other._values
331
+
332
+
333
+ class ScalarMap(MutableMapping[_K, _V]):
334
+ """Simple, type-checked, dict-like container for holding repeated scalars."""
335
+
336
+ # Disallows assignment to other attributes.
337
+ __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
338
+ '_entry_descriptor']
339
+
340
+ def __init__(
341
+ self,
342
+ message_listener: Any,
343
+ key_checker: Any,
344
+ value_checker: Any,
345
+ entry_descriptor: Any,
346
+ ) -> None:
347
+ """
348
+ Args:
349
+ message_listener: A MessageListener implementation.
350
+ The ScalarMap will call this object's Modified() method when it
351
+ is modified.
352
+ key_checker: A type_checkers.ValueChecker instance to run on keys
353
+ inserted into this container.
354
+ value_checker: A type_checkers.ValueChecker instance to run on values
355
+ inserted into this container.
356
+ entry_descriptor: The MessageDescriptor of a map entry: key and value.
357
+ """
358
+ self._message_listener = message_listener
359
+ self._key_checker = key_checker
360
+ self._value_checker = value_checker
361
+ self._entry_descriptor = entry_descriptor
362
+ self._values = {}
363
+
364
+ def __getitem__(self, key: _K) -> _V:
365
+ try:
366
+ return self._values[key]
367
+ except KeyError:
368
+ key = self._key_checker.CheckValue(key)
369
+ val = self._value_checker.DefaultValue()
370
+ self._values[key] = val
371
+ return val
372
+
373
+ def __contains__(self, item: _K) -> bool:
374
+ # We check the key's type to match the strong-typing flavor of the API.
375
+ # Also this makes it easier to match the behavior of the C++ implementation.
376
+ self._key_checker.CheckValue(item)
377
+ return item in self._values
378
+
379
+ @overload
380
+ def get(self, key: _K) -> Optional[_V]:
381
+ ...
382
+
383
+ @overload
384
+ def get(self, key: _K, default: _T) -> Union[_V, _T]:
385
+ ...
386
+
387
+ # We need to override this explicitly, because our defaultdict-like behavior
388
+ # will make the default implementation (from our base class) always insert
389
+ # the key.
390
+ def get(self, key, default=None):
391
+ if key in self:
392
+ return self[key]
393
+ else:
394
+ return default
395
+
396
+ def __setitem__(self, key: _K, value: _V) -> _T:
397
+ checked_key = self._key_checker.CheckValue(key)
398
+ checked_value = self._value_checker.CheckValue(value)
399
+ self._values[checked_key] = checked_value
400
+ self._message_listener.Modified()
401
+
402
+ def __delitem__(self, key: _K) -> None:
403
+ del self._values[key]
404
+ self._message_listener.Modified()
405
+
406
+ def __len__(self) -> int:
407
+ return len(self._values)
408
+
409
+ def __iter__(self) -> Iterator[_K]:
410
+ return iter(self._values)
411
+
412
+ def __repr__(self) -> str:
413
+ return repr(self._values)
414
+
415
+ def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
416
+ self._values.update(other._values)
417
+ self._message_listener.Modified()
418
+
419
+ def InvalidateIterators(self) -> None:
420
+ # It appears that the only way to reliably invalidate iterators to
421
+ # self._values is to ensure that its size changes.
422
+ original = self._values
423
+ self._values = original.copy()
424
+ original[None] = None
425
+
426
+ # This is defined in the abstract base, but we can do it much more cheaply.
427
+ def clear(self) -> None:
428
+ self._values.clear()
429
+ self._message_listener.Modified()
430
+
431
+ def GetEntryClass(self) -> Any:
432
+ return self._entry_descriptor._concrete_class
433
+
434
+
435
+ class MessageMap(MutableMapping[_K, _V]):
436
+ """Simple, type-checked, dict-like container for with submessage values."""
437
+
438
+ # Disallows assignment to other attributes.
439
+ __slots__ = ['_key_checker', '_values', '_message_listener',
440
+ '_message_descriptor', '_entry_descriptor']
441
+
442
+ def __init__(
443
+ self,
444
+ message_listener: Any,
445
+ message_descriptor: Any,
446
+ key_checker: Any,
447
+ entry_descriptor: Any,
448
+ ) -> None:
449
+ """
450
+ Args:
451
+ message_listener: A MessageListener implementation.
452
+ The ScalarMap will call this object's Modified() method when it
453
+ is modified.
454
+ key_checker: A type_checkers.ValueChecker instance to run on keys
455
+ inserted into this container.
456
+ value_checker: A type_checkers.ValueChecker instance to run on values
457
+ inserted into this container.
458
+ entry_descriptor: The MessageDescriptor of a map entry: key and value.
459
+ """
460
+ self._message_listener = message_listener
461
+ self._message_descriptor = message_descriptor
462
+ self._key_checker = key_checker
463
+ self._entry_descriptor = entry_descriptor
464
+ self._values = {}
465
+
466
+ def __getitem__(self, key: _K) -> _V:
467
+ key = self._key_checker.CheckValue(key)
468
+ try:
469
+ return self._values[key]
470
+ except KeyError:
471
+ new_element = self._message_descriptor._concrete_class()
472
+ new_element._SetListener(self._message_listener)
473
+ self._values[key] = new_element
474
+ self._message_listener.Modified()
475
+ return new_element
476
+
477
+ def get_or_create(self, key: _K) -> _V:
478
+ """get_or_create() is an alias for getitem (ie. map[key]).
479
+
480
+ Args:
481
+ key: The key to get or create in the map.
482
+
483
+ This is useful in cases where you want to be explicit that the call is
484
+ mutating the map. This can avoid lint errors for statements like this
485
+ that otherwise would appear to be pointless statements:
486
+
487
+ msg.my_map[key]
488
+ """
489
+ return self[key]
490
+
491
+ @overload
492
+ def get(self, key: _K) -> Optional[_V]:
493
+ ...
494
+
495
+ @overload
496
+ def get(self, key: _K, default: _T) -> Union[_V, _T]:
497
+ ...
498
+
499
+ # We need to override this explicitly, because our defaultdict-like behavior
500
+ # will make the default implementation (from our base class) always insert
501
+ # the key.
502
+ def get(self, key, default=None):
503
+ if key in self:
504
+ return self[key]
505
+ else:
506
+ return default
507
+
508
+ def __contains__(self, item: _K) -> bool:
509
+ item = self._key_checker.CheckValue(item)
510
+ return item in self._values
511
+
512
+ def __setitem__(self, key: _K, value: _V) -> NoReturn:
513
+ raise ValueError('May not set values directly, call my_map[key].foo = 5')
514
+
515
+ def __delitem__(self, key: _K) -> None:
516
+ key = self._key_checker.CheckValue(key)
517
+ del self._values[key]
518
+ self._message_listener.Modified()
519
+
520
+ def __len__(self) -> int:
521
+ return len(self._values)
522
+
523
+ def __iter__(self) -> Iterator[_K]:
524
+ return iter(self._values)
525
+
526
+ def __repr__(self) -> str:
527
+ return repr(self._values)
528
+
529
+ def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
530
+ # pylint: disable=protected-access
531
+ for key in other._values:
532
+ # According to documentation: "When parsing from the wire or when merging,
533
+ # if there are duplicate map keys the last key seen is used".
534
+ if key in self:
535
+ del self[key]
536
+ self[key].CopyFrom(other[key])
537
+ # self._message_listener.Modified() not required here, because
538
+ # mutations to submessages already propagate.
539
+
540
+ def InvalidateIterators(self) -> None:
541
+ # It appears that the only way to reliably invalidate iterators to
542
+ # self._values is to ensure that its size changes.
543
+ original = self._values
544
+ self._values = original.copy()
545
+ original[None] = None
546
+
547
+ # This is defined in the abstract base, but we can do it much more cheaply.
548
+ def clear(self) -> None:
549
+ self._values.clear()
550
+ self._message_listener.Modified()
551
+
552
+ def GetEntryClass(self) -> Any:
553
+ return self._entry_descriptor._concrete_class
554
+
555
+
556
+ class _UnknownField:
557
+ """A parsed unknown field."""
558
+
559
+ # Disallows assignment to other attributes.
560
+ __slots__ = ['_field_number', '_wire_type', '_data']
561
+
562
+ def __init__(self, field_number, wire_type, data):
563
+ self._field_number = field_number
564
+ self._wire_type = wire_type
565
+ self._data = data
566
+ return
567
+
568
+ def __lt__(self, other):
569
+ # pylint: disable=protected-access
570
+ return self._field_number < other._field_number
571
+
572
+ def __eq__(self, other):
573
+ if self is other:
574
+ return True
575
+ # pylint: disable=protected-access
576
+ return (self._field_number == other._field_number and
577
+ self._wire_type == other._wire_type and
578
+ self._data == other._data)
579
+
580
+
581
+ class UnknownFieldRef: # pylint: disable=missing-class-docstring
582
+
583
+ def __init__(self, parent, index):
584
+ self._parent = parent
585
+ self._index = index
586
+
587
+ def _check_valid(self):
588
+ if not self._parent:
589
+ raise ValueError('UnknownField does not exist. '
590
+ 'The parent message might be cleared.')
591
+ if self._index >= len(self._parent):
592
+ raise ValueError('UnknownField does not exist. '
593
+ 'The parent message might be cleared.')
594
+
595
+ @property
596
+ def field_number(self):
597
+ self._check_valid()
598
+ # pylint: disable=protected-access
599
+ return self._parent._internal_get(self._index)._field_number
600
+
601
+ @property
602
+ def wire_type(self):
603
+ self._check_valid()
604
+ # pylint: disable=protected-access
605
+ return self._parent._internal_get(self._index)._wire_type
606
+
607
+ @property
608
+ def data(self):
609
+ self._check_valid()
610
+ # pylint: disable=protected-access
611
+ return self._parent._internal_get(self._index)._data
612
+
613
+
614
+ class UnknownFieldSet:
615
+ """UnknownField container"""
616
+
617
+ # Disallows assignment to other attributes.
618
+ __slots__ = ['_values']
619
+
620
+ def __init__(self):
621
+ self._values = []
622
+
623
+ def __getitem__(self, index):
624
+ if self._values is None:
625
+ raise ValueError('UnknownFields does not exist. '
626
+ 'The parent message might be cleared.')
627
+ size = len(self._values)
628
+ if index < 0:
629
+ index += size
630
+ if index < 0 or index >= size:
631
+ raise IndexError('index %d out of range'.index)
632
+
633
+ return UnknownFieldRef(self, index)
634
+
635
+ def _internal_get(self, index):
636
+ return self._values[index]
637
+
638
+ def __len__(self):
639
+ if self._values is None:
640
+ raise ValueError('UnknownFields does not exist. '
641
+ 'The parent message might be cleared.')
642
+ return len(self._values)
643
+
644
+ def _add(self, field_number, wire_type, data):
645
+ unknown_field = _UnknownField(field_number, wire_type, data)
646
+ self._values.append(unknown_field)
647
+ return unknown_field
648
+
649
+ def __iter__(self):
650
+ for i in range(len(self)):
651
+ yield UnknownFieldRef(self, i)
652
+
653
+ def _extend(self, other):
654
+ if other is None:
655
+ return
656
+ # pylint: disable=protected-access
657
+ self._values.extend(other._values)
658
+
659
+ def __eq__(self, other):
660
+ if self is other:
661
+ return True
662
+ # Sort unknown fields because their order shouldn't
663
+ # affect equality test.
664
+ values = list(self._values)
665
+ if other is None:
666
+ return not values
667
+ values.sort()
668
+ # pylint: disable=protected-access
669
+ other_values = sorted(other._values)
670
+ return values == other_values
671
+
672
+ def _clear(self):
673
+ for value in self._values:
674
+ # pylint: disable=protected-access
675
+ if isinstance(value._data, UnknownFieldSet):
676
+ value._data._clear() # pylint: disable=protected-access
677
+ self._values = None
.venv/lib/python3.11/site-packages/google/protobuf/internal/encoder.py ADDED
@@ -0,0 +1,806 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protocol Buffers - Google's data interchange format
2
+ # Copyright 2008 Google Inc. All rights reserved.
3
+ #
4
+ # Use of this source code is governed by a BSD-style
5
+ # license that can be found in the LICENSE file or at
6
+ # https://developers.google.com/open-source/licenses/bsd
7
+
8
+ """Code for encoding protocol message primitives.
9
+
10
+ Contains the logic for encoding every logical protocol field type
11
+ into one of the 5 physical wire types.
12
+
13
+ This code is designed to push the Python interpreter's performance to the
14
+ limits.
15
+
16
+ The basic idea is that at startup time, for every field (i.e. every
17
+ FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The
18
+ sizer takes a value of this field's type and computes its byte size. The
19
+ encoder takes a writer function and a value. It encodes the value into byte
20
+ strings and invokes the writer function to write those strings. Typically the
21
+ writer function is the write() method of a BytesIO.
22
+
23
+ We try to do as much work as possible when constructing the writer and the
24
+ sizer rather than when calling them. In particular:
25
+ * We copy any needed global functions to local variables, so that we do not need
26
+ to do costly global table lookups at runtime.
27
+ * Similarly, we try to do any attribute lookups at startup time if possible.
28
+ * Every field's tag is encoded to bytes at startup, since it can't change at
29
+ runtime.
30
+ * Whatever component of the field size we can compute at startup, we do.
31
+ * We *avoid* sharing code if doing so would make the code slower and not sharing
32
+ does not burden us too much. For example, encoders for repeated fields do
33
+ not just call the encoders for singular fields in a loop because this would
34
+ add an extra function call overhead for every loop iteration; instead, we
35
+ manually inline the single-value encoder into the loop.
36
+ * If a Python function lacks a return statement, Python actually generates
37
+ instructions to pop the result of the last statement off the stack, push
38
+ None onto the stack, and then return that. If we really don't care what
39
+ value is returned, then we can save two instructions by returning the
40
+ result of the last statement. It looks funny but it helps.
41
+ * We assume that type and bounds checking has happened at a higher level.
42
+ """
43
+
44
+ __author__ = 'kenton@google.com (Kenton Varda)'
45
+
46
+ import struct
47
+
48
+ from google.protobuf.internal import wire_format
49
+
50
+
51
+ # This will overflow and thus become IEEE-754 "infinity". We would use
52
+ # "float('inf')" but it doesn't work on Windows pre-Python-2.6.
53
+ _POS_INF = 1e10000
54
+ _NEG_INF = -_POS_INF
55
+
56
+
57
+ def _VarintSize(value):
58
+ """Compute the size of a varint value."""
59
+ if value <= 0x7f: return 1
60
+ if value <= 0x3fff: return 2
61
+ if value <= 0x1fffff: return 3
62
+ if value <= 0xfffffff: return 4
63
+ if value <= 0x7ffffffff: return 5
64
+ if value <= 0x3ffffffffff: return 6
65
+ if value <= 0x1ffffffffffff: return 7
66
+ if value <= 0xffffffffffffff: return 8
67
+ if value <= 0x7fffffffffffffff: return 9
68
+ return 10
69
+
70
+
71
+ def _SignedVarintSize(value):
72
+ """Compute the size of a signed varint value."""
73
+ if value < 0: return 10
74
+ if value <= 0x7f: return 1
75
+ if value <= 0x3fff: return 2
76
+ if value <= 0x1fffff: return 3
77
+ if value <= 0xfffffff: return 4
78
+ if value <= 0x7ffffffff: return 5
79
+ if value <= 0x3ffffffffff: return 6
80
+ if value <= 0x1ffffffffffff: return 7
81
+ if value <= 0xffffffffffffff: return 8
82
+ if value <= 0x7fffffffffffffff: return 9
83
+ return 10
84
+
85
+
86
+ def _TagSize(field_number):
87
+ """Returns the number of bytes required to serialize a tag with this field
88
+ number."""
89
+ # Just pass in type 0, since the type won't affect the tag+type size.
90
+ return _VarintSize(wire_format.PackTag(field_number, 0))
91
+
92
+
93
+ # --------------------------------------------------------------------
94
+ # In this section we define some generic sizers. Each of these functions
95
+ # takes parameters specific to a particular field type, e.g. int32 or fixed64.
96
+ # It returns another function which in turn takes parameters specific to a
97
+ # particular field, e.g. the field number and whether it is repeated or packed.
98
+ # Look at the next section to see how these are used.
99
+
100
+
101
+ def _SimpleSizer(compute_value_size):
102
+ """A sizer which uses the function compute_value_size to compute the size of
103
+ each value. Typically compute_value_size is _VarintSize."""
104
+
105
+ def SpecificSizer(field_number, is_repeated, is_packed):
106
+ tag_size = _TagSize(field_number)
107
+ if is_packed:
108
+ local_VarintSize = _VarintSize
109
+ def PackedFieldSize(value):
110
+ result = 0
111
+ for element in value:
112
+ result += compute_value_size(element)
113
+ return result + local_VarintSize(result) + tag_size
114
+ return PackedFieldSize
115
+ elif is_repeated:
116
+ def RepeatedFieldSize(value):
117
+ result = tag_size * len(value)
118
+ for element in value:
119
+ result += compute_value_size(element)
120
+ return result
121
+ return RepeatedFieldSize
122
+ else:
123
+ def FieldSize(value):
124
+ return tag_size + compute_value_size(value)
125
+ return FieldSize
126
+
127
+ return SpecificSizer
128
+
129
+
130
+ def _ModifiedSizer(compute_value_size, modify_value):
131
+ """Like SimpleSizer, but modify_value is invoked on each value before it is
132
+ passed to compute_value_size. modify_value is typically ZigZagEncode."""
133
+
134
+ def SpecificSizer(field_number, is_repeated, is_packed):
135
+ tag_size = _TagSize(field_number)
136
+ if is_packed:
137
+ local_VarintSize = _VarintSize
138
+ def PackedFieldSize(value):
139
+ result = 0
140
+ for element in value:
141
+ result += compute_value_size(modify_value(element))
142
+ return result + local_VarintSize(result) + tag_size
143
+ return PackedFieldSize
144
+ elif is_repeated:
145
+ def RepeatedFieldSize(value):
146
+ result = tag_size * len(value)
147
+ for element in value:
148
+ result += compute_value_size(modify_value(element))
149
+ return result
150
+ return RepeatedFieldSize
151
+ else:
152
+ def FieldSize(value):
153
+ return tag_size + compute_value_size(modify_value(value))
154
+ return FieldSize
155
+
156
+ return SpecificSizer
157
+
158
+
159
+ def _FixedSizer(value_size):
160
+ """Like _SimpleSizer except for a fixed-size field. The input is the size
161
+ of one value."""
162
+
163
+ def SpecificSizer(field_number, is_repeated, is_packed):
164
+ tag_size = _TagSize(field_number)
165
+ if is_packed:
166
+ local_VarintSize = _VarintSize
167
+ def PackedFieldSize(value):
168
+ result = len(value) * value_size
169
+ return result + local_VarintSize(result) + tag_size
170
+ return PackedFieldSize
171
+ elif is_repeated:
172
+ element_size = value_size + tag_size
173
+ def RepeatedFieldSize(value):
174
+ return len(value) * element_size
175
+ return RepeatedFieldSize
176
+ else:
177
+ field_size = value_size + tag_size
178
+ def FieldSize(value):
179
+ return field_size
180
+ return FieldSize
181
+
182
+ return SpecificSizer
183
+
184
+
185
+ # ====================================================================
186
+ # Here we declare a sizer constructor for each field type. Each "sizer
187
+ # constructor" is a function that takes (field_number, is_repeated, is_packed)
188
+ # as parameters and returns a sizer, which in turn takes a field value as
189
+ # a parameter and returns its encoded size.
190
+
191
+
192
+ Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize)
193
+
194
+ UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize)
195
+
196
+ SInt32Sizer = SInt64Sizer = _ModifiedSizer(
197
+ _SignedVarintSize, wire_format.ZigZagEncode)
198
+
199
+ Fixed32Sizer = SFixed32Sizer = FloatSizer = _FixedSizer(4)
200
+ Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8)
201
+
202
+ BoolSizer = _FixedSizer(1)
203
+
204
+
205
+ def StringSizer(field_number, is_repeated, is_packed):
206
+ """Returns a sizer for a string field."""
207
+
208
+ tag_size = _TagSize(field_number)
209
+ local_VarintSize = _VarintSize
210
+ local_len = len
211
+ assert not is_packed
212
+ if is_repeated:
213
+ def RepeatedFieldSize(value):
214
+ result = tag_size * len(value)
215
+ for element in value:
216
+ l = local_len(element.encode('utf-8'))
217
+ result += local_VarintSize(l) + l
218
+ return result
219
+ return RepeatedFieldSize
220
+ else:
221
+ def FieldSize(value):
222
+ l = local_len(value.encode('utf-8'))
223
+ return tag_size + local_VarintSize(l) + l
224
+ return FieldSize
225
+
226
+
227
+ def BytesSizer(field_number, is_repeated, is_packed):
228
+ """Returns a sizer for a bytes field."""
229
+
230
+ tag_size = _TagSize(field_number)
231
+ local_VarintSize = _VarintSize
232
+ local_len = len
233
+ assert not is_packed
234
+ if is_repeated:
235
+ def RepeatedFieldSize(value):
236
+ result = tag_size * len(value)
237
+ for element in value:
238
+ l = local_len(element)
239
+ result += local_VarintSize(l) + l
240
+ return result
241
+ return RepeatedFieldSize
242
+ else:
243
+ def FieldSize(value):
244
+ l = local_len(value)
245
+ return tag_size + local_VarintSize(l) + l
246
+ return FieldSize
247
+
248
+
249
+ def GroupSizer(field_number, is_repeated, is_packed):
250
+ """Returns a sizer for a group field."""
251
+
252
+ tag_size = _TagSize(field_number) * 2
253
+ assert not is_packed
254
+ if is_repeated:
255
+ def RepeatedFieldSize(value):
256
+ result = tag_size * len(value)
257
+ for element in value:
258
+ result += element.ByteSize()
259
+ return result
260
+ return RepeatedFieldSize
261
+ else:
262
+ def FieldSize(value):
263
+ return tag_size + value.ByteSize()
264
+ return FieldSize
265
+
266
+
267
+ def MessageSizer(field_number, is_repeated, is_packed):
268
+ """Returns a sizer for a message field."""
269
+
270
+ tag_size = _TagSize(field_number)
271
+ local_VarintSize = _VarintSize
272
+ assert not is_packed
273
+ if is_repeated:
274
+ def RepeatedFieldSize(value):
275
+ result = tag_size * len(value)
276
+ for element in value:
277
+ l = element.ByteSize()
278
+ result += local_VarintSize(l) + l
279
+ return result
280
+ return RepeatedFieldSize
281
+ else:
282
+ def FieldSize(value):
283
+ l = value.ByteSize()
284
+ return tag_size + local_VarintSize(l) + l
285
+ return FieldSize
286
+
287
+
288
+ # --------------------------------------------------------------------
289
+ # MessageSet is special: it needs custom logic to compute its size properly.
290
+
291
+
292
+ def MessageSetItemSizer(field_number):
293
+ """Returns a sizer for extensions of MessageSet.
294
+
295
+ The message set message looks like this:
296
+ message MessageSet {
297
+ repeated group Item = 1 {
298
+ required int32 type_id = 2;
299
+ required string message = 3;
300
+ }
301
+ }
302
+ """
303
+ static_size = (_TagSize(1) * 2 + _TagSize(2) + _VarintSize(field_number) +
304
+ _TagSize(3))
305
+ local_VarintSize = _VarintSize
306
+
307
+ def FieldSize(value):
308
+ l = value.ByteSize()
309
+ return static_size + local_VarintSize(l) + l
310
+
311
+ return FieldSize
312
+
313
+
314
+ # --------------------------------------------------------------------
315
+ # Map is special: it needs custom logic to compute its size properly.
316
+
317
+
318
+ def MapSizer(field_descriptor, is_message_map):
319
+ """Returns a sizer for a map field."""
320
+
321
+ # Can't look at field_descriptor.message_type._concrete_class because it may
322
+ # not have been initialized yet.
323
+ message_type = field_descriptor.message_type
324
+ message_sizer = MessageSizer(field_descriptor.number, False, False)
325
+
326
+ def FieldSize(map_value):
327
+ total = 0
328
+ for key in map_value:
329
+ value = map_value[key]
330
+ # It's wasteful to create the messages and throw them away one second
331
+ # later since we'll do the same for the actual encode. But there's not an
332
+ # obvious way to avoid this within the current design without tons of code
333
+ # duplication. For message map, value.ByteSize() should be called to
334
+ # update the status.
335
+ entry_msg = message_type._concrete_class(key=key, value=value)
336
+ total += message_sizer(entry_msg)
337
+ if is_message_map:
338
+ value.ByteSize()
339
+ return total
340
+
341
+ return FieldSize
342
+
343
+ # ====================================================================
344
+ # Encoders!
345
+
346
+
347
+ def _VarintEncoder():
348
+ """Return an encoder for a basic varint value (does not include tag)."""
349
+
350
+ local_int2byte = struct.Struct('>B').pack
351
+
352
+ def EncodeVarint(write, value, unused_deterministic=None):
353
+ bits = value & 0x7f
354
+ value >>= 7
355
+ while value:
356
+ write(local_int2byte(0x80|bits))
357
+ bits = value & 0x7f
358
+ value >>= 7
359
+ return write(local_int2byte(bits))
360
+
361
+ return EncodeVarint
362
+
363
+
364
+ def _SignedVarintEncoder():
365
+ """Return an encoder for a basic signed varint value (does not include
366
+ tag)."""
367
+
368
+ local_int2byte = struct.Struct('>B').pack
369
+
370
+ def EncodeSignedVarint(write, value, unused_deterministic=None):
371
+ if value < 0:
372
+ value += (1 << 64)
373
+ bits = value & 0x7f
374
+ value >>= 7
375
+ while value:
376
+ write(local_int2byte(0x80|bits))
377
+ bits = value & 0x7f
378
+ value >>= 7
379
+ return write(local_int2byte(bits))
380
+
381
+ return EncodeSignedVarint
382
+
383
+
384
+ _EncodeVarint = _VarintEncoder()
385
+ _EncodeSignedVarint = _SignedVarintEncoder()
386
+
387
+
388
+ def _VarintBytes(value):
389
+ """Encode the given integer as a varint and return the bytes. This is only
390
+ called at startup time so it doesn't need to be fast."""
391
+
392
+ pieces = []
393
+ _EncodeVarint(pieces.append, value, True)
394
+ return b"".join(pieces)
395
+
396
+
397
+ def TagBytes(field_number, wire_type):
398
+ """Encode the given tag and return the bytes. Only called at startup."""
399
+
400
+ return bytes(_VarintBytes(wire_format.PackTag(field_number, wire_type)))
401
+
402
+ # --------------------------------------------------------------------
403
+ # As with sizers (see above), we have a number of common encoder
404
+ # implementations.
405
+
406
+
407
+ def _SimpleEncoder(wire_type, encode_value, compute_value_size):
408
+ """Return a constructor for an encoder for fields of a particular type.
409
+
410
+ Args:
411
+ wire_type: The field's wire type, for encoding tags.
412
+ encode_value: A function which encodes an individual value, e.g.
413
+ _EncodeVarint().
414
+ compute_value_size: A function which computes the size of an individual
415
+ value, e.g. _VarintSize().
416
+ """
417
+
418
+ def SpecificEncoder(field_number, is_repeated, is_packed):
419
+ if is_packed:
420
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
421
+ local_EncodeVarint = _EncodeVarint
422
+ def EncodePackedField(write, value, deterministic):
423
+ write(tag_bytes)
424
+ size = 0
425
+ for element in value:
426
+ size += compute_value_size(element)
427
+ local_EncodeVarint(write, size, deterministic)
428
+ for element in value:
429
+ encode_value(write, element, deterministic)
430
+ return EncodePackedField
431
+ elif is_repeated:
432
+ tag_bytes = TagBytes(field_number, wire_type)
433
+ def EncodeRepeatedField(write, value, deterministic):
434
+ for element in value:
435
+ write(tag_bytes)
436
+ encode_value(write, element, deterministic)
437
+ return EncodeRepeatedField
438
+ else:
439
+ tag_bytes = TagBytes(field_number, wire_type)
440
+ def EncodeField(write, value, deterministic):
441
+ write(tag_bytes)
442
+ return encode_value(write, value, deterministic)
443
+ return EncodeField
444
+
445
+ return SpecificEncoder
446
+
447
+
448
+ def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
449
+ """Like SimpleEncoder but additionally invokes modify_value on every value
450
+ before passing it to encode_value. Usually modify_value is ZigZagEncode."""
451
+
452
+ def SpecificEncoder(field_number, is_repeated, is_packed):
453
+ if is_packed:
454
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
455
+ local_EncodeVarint = _EncodeVarint
456
+ def EncodePackedField(write, value, deterministic):
457
+ write(tag_bytes)
458
+ size = 0
459
+ for element in value:
460
+ size += compute_value_size(modify_value(element))
461
+ local_EncodeVarint(write, size, deterministic)
462
+ for element in value:
463
+ encode_value(write, modify_value(element), deterministic)
464
+ return EncodePackedField
465
+ elif is_repeated:
466
+ tag_bytes = TagBytes(field_number, wire_type)
467
+ def EncodeRepeatedField(write, value, deterministic):
468
+ for element in value:
469
+ write(tag_bytes)
470
+ encode_value(write, modify_value(element), deterministic)
471
+ return EncodeRepeatedField
472
+ else:
473
+ tag_bytes = TagBytes(field_number, wire_type)
474
+ def EncodeField(write, value, deterministic):
475
+ write(tag_bytes)
476
+ return encode_value(write, modify_value(value), deterministic)
477
+ return EncodeField
478
+
479
+ return SpecificEncoder
480
+
481
+
482
+ def _StructPackEncoder(wire_type, format):
483
+ """Return a constructor for an encoder for a fixed-width field.
484
+
485
+ Args:
486
+ wire_type: The field's wire type, for encoding tags.
487
+ format: The format string to pass to struct.pack().
488
+ """
489
+
490
+ value_size = struct.calcsize(format)
491
+
492
+ def SpecificEncoder(field_number, is_repeated, is_packed):
493
+ local_struct_pack = struct.pack
494
+ if is_packed:
495
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
496
+ local_EncodeVarint = _EncodeVarint
497
+ def EncodePackedField(write, value, deterministic):
498
+ write(tag_bytes)
499
+ local_EncodeVarint(write, len(value) * value_size, deterministic)
500
+ for element in value:
501
+ write(local_struct_pack(format, element))
502
+ return EncodePackedField
503
+ elif is_repeated:
504
+ tag_bytes = TagBytes(field_number, wire_type)
505
+ def EncodeRepeatedField(write, value, unused_deterministic=None):
506
+ for element in value:
507
+ write(tag_bytes)
508
+ write(local_struct_pack(format, element))
509
+ return EncodeRepeatedField
510
+ else:
511
+ tag_bytes = TagBytes(field_number, wire_type)
512
+ def EncodeField(write, value, unused_deterministic=None):
513
+ write(tag_bytes)
514
+ return write(local_struct_pack(format, value))
515
+ return EncodeField
516
+
517
+ return SpecificEncoder
518
+
519
+
520
+ def _FloatingPointEncoder(wire_type, format):
521
+ """Return a constructor for an encoder for float fields.
522
+
523
+ This is like StructPackEncoder, but catches errors that may be due to
524
+ passing non-finite floating-point values to struct.pack, and makes a
525
+ second attempt to encode those values.
526
+
527
+ Args:
528
+ wire_type: The field's wire type, for encoding tags.
529
+ format: The format string to pass to struct.pack().
530
+ """
531
+
532
+ value_size = struct.calcsize(format)
533
+ if value_size == 4:
534
+ def EncodeNonFiniteOrRaise(write, value):
535
+ # Remember that the serialized form uses little-endian byte order.
536
+ if value == _POS_INF:
537
+ write(b'\x00\x00\x80\x7F')
538
+ elif value == _NEG_INF:
539
+ write(b'\x00\x00\x80\xFF')
540
+ elif value != value: # NaN
541
+ write(b'\x00\x00\xC0\x7F')
542
+ else:
543
+ raise
544
+ elif value_size == 8:
545
+ def EncodeNonFiniteOrRaise(write, value):
546
+ if value == _POS_INF:
547
+ write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
548
+ elif value == _NEG_INF:
549
+ write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
550
+ elif value != value: # NaN
551
+ write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
552
+ else:
553
+ raise
554
+ else:
555
+ raise ValueError('Can\'t encode floating-point values that are '
556
+ '%d bytes long (only 4 or 8)' % value_size)
557
+
558
+ def SpecificEncoder(field_number, is_repeated, is_packed):
559
+ local_struct_pack = struct.pack
560
+ if is_packed:
561
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
562
+ local_EncodeVarint = _EncodeVarint
563
+ def EncodePackedField(write, value, deterministic):
564
+ write(tag_bytes)
565
+ local_EncodeVarint(write, len(value) * value_size, deterministic)
566
+ for element in value:
567
+ # This try/except block is going to be faster than any code that
568
+ # we could write to check whether element is finite.
569
+ try:
570
+ write(local_struct_pack(format, element))
571
+ except SystemError:
572
+ EncodeNonFiniteOrRaise(write, element)
573
+ return EncodePackedField
574
+ elif is_repeated:
575
+ tag_bytes = TagBytes(field_number, wire_type)
576
+ def EncodeRepeatedField(write, value, unused_deterministic=None):
577
+ for element in value:
578
+ write(tag_bytes)
579
+ try:
580
+ write(local_struct_pack(format, element))
581
+ except SystemError:
582
+ EncodeNonFiniteOrRaise(write, element)
583
+ return EncodeRepeatedField
584
+ else:
585
+ tag_bytes = TagBytes(field_number, wire_type)
586
+ def EncodeField(write, value, unused_deterministic=None):
587
+ write(tag_bytes)
588
+ try:
589
+ write(local_struct_pack(format, value))
590
+ except SystemError:
591
+ EncodeNonFiniteOrRaise(write, value)
592
+ return EncodeField
593
+
594
+ return SpecificEncoder
595
+
596
+
597
+ # ====================================================================
598
+ # Here we declare an encoder constructor for each field type. These work
599
+ # very similarly to sizer constructors, described earlier.
600
+
601
+
602
+ Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder(
603
+ wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)
604
+
605
+ UInt32Encoder = UInt64Encoder = _SimpleEncoder(
606
+ wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)
607
+
608
+ SInt32Encoder = SInt64Encoder = _ModifiedEncoder(
609
+ wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize,
610
+ wire_format.ZigZagEncode)
611
+
612
+ # Note that Python conveniently guarantees that when using the '<' prefix on
613
+ # formats, they will also have the same size across all platforms (as opposed
614
+ # to without the prefix, where their sizes depend on the C compiler's basic
615
+ # type sizes).
616
+ Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
617
+ Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
618
+ SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
619
+ SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
620
+ FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f')
621
+ DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d')
622
+
623
+
624
+ def BoolEncoder(field_number, is_repeated, is_packed):
625
+ """Returns an encoder for a boolean field."""
626
+
627
+ false_byte = b'\x00'
628
+ true_byte = b'\x01'
629
+ if is_packed:
630
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
631
+ local_EncodeVarint = _EncodeVarint
632
+ def EncodePackedField(write, value, deterministic):
633
+ write(tag_bytes)
634
+ local_EncodeVarint(write, len(value), deterministic)
635
+ for element in value:
636
+ if element:
637
+ write(true_byte)
638
+ else:
639
+ write(false_byte)
640
+ return EncodePackedField
641
+ elif is_repeated:
642
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
643
+ def EncodeRepeatedField(write, value, unused_deterministic=None):
644
+ for element in value:
645
+ write(tag_bytes)
646
+ if element:
647
+ write(true_byte)
648
+ else:
649
+ write(false_byte)
650
+ return EncodeRepeatedField
651
+ else:
652
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
653
+ def EncodeField(write, value, unused_deterministic=None):
654
+ write(tag_bytes)
655
+ if value:
656
+ return write(true_byte)
657
+ return write(false_byte)
658
+ return EncodeField
659
+
660
+
661
+ def StringEncoder(field_number, is_repeated, is_packed):
662
+ """Returns an encoder for a string field."""
663
+
664
+ tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
665
+ local_EncodeVarint = _EncodeVarint
666
+ local_len = len
667
+ assert not is_packed
668
+ if is_repeated:
669
+ def EncodeRepeatedField(write, value, deterministic):
670
+ for element in value:
671
+ encoded = element.encode('utf-8')
672
+ write(tag)
673
+ local_EncodeVarint(write, local_len(encoded), deterministic)
674
+ write(encoded)
675
+ return EncodeRepeatedField
676
+ else:
677
+ def EncodeField(write, value, deterministic):
678
+ encoded = value.encode('utf-8')
679
+ write(tag)
680
+ local_EncodeVarint(write, local_len(encoded), deterministic)
681
+ return write(encoded)
682
+ return EncodeField
683
+
684
+
685
+ def BytesEncoder(field_number, is_repeated, is_packed):
686
+ """Returns an encoder for a bytes field."""
687
+
688
+ tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
689
+ local_EncodeVarint = _EncodeVarint
690
+ local_len = len
691
+ assert not is_packed
692
+ if is_repeated:
693
+ def EncodeRepeatedField(write, value, deterministic):
694
+ for element in value:
695
+ write(tag)
696
+ local_EncodeVarint(write, local_len(element), deterministic)
697
+ write(element)
698
+ return EncodeRepeatedField
699
+ else:
700
+ def EncodeField(write, value, deterministic):
701
+ write(tag)
702
+ local_EncodeVarint(write, local_len(value), deterministic)
703
+ return write(value)
704
+ return EncodeField
705
+
706
+
707
+ def GroupEncoder(field_number, is_repeated, is_packed):
708
+ """Returns an encoder for a group field."""
709
+
710
+ start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP)
711
+ end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
712
+ assert not is_packed
713
+ if is_repeated:
714
+ def EncodeRepeatedField(write, value, deterministic):
715
+ for element in value:
716
+ write(start_tag)
717
+ element._InternalSerialize(write, deterministic)
718
+ write(end_tag)
719
+ return EncodeRepeatedField
720
+ else:
721
+ def EncodeField(write, value, deterministic):
722
+ write(start_tag)
723
+ value._InternalSerialize(write, deterministic)
724
+ return write(end_tag)
725
+ return EncodeField
726
+
727
+
728
+ def MessageEncoder(field_number, is_repeated, is_packed):
729
+ """Returns an encoder for a message field."""
730
+
731
+ tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
732
+ local_EncodeVarint = _EncodeVarint
733
+ assert not is_packed
734
+ if is_repeated:
735
+ def EncodeRepeatedField(write, value, deterministic):
736
+ for element in value:
737
+ write(tag)
738
+ local_EncodeVarint(write, element.ByteSize(), deterministic)
739
+ element._InternalSerialize(write, deterministic)
740
+ return EncodeRepeatedField
741
+ else:
742
+ def EncodeField(write, value, deterministic):
743
+ write(tag)
744
+ local_EncodeVarint(write, value.ByteSize(), deterministic)
745
+ return value._InternalSerialize(write, deterministic)
746
+ return EncodeField
747
+
748
+
749
+ # --------------------------------------------------------------------
750
+ # As before, MessageSet is special.
751
+
752
+
753
+ def MessageSetItemEncoder(field_number):
754
+ """Encoder for extensions of MessageSet.
755
+
756
+ The message set message looks like this:
757
+ message MessageSet {
758
+ repeated group Item = 1 {
759
+ required int32 type_id = 2;
760
+ required string message = 3;
761
+ }
762
+ }
763
+ """
764
+ start_bytes = b"".join([
765
+ TagBytes(1, wire_format.WIRETYPE_START_GROUP),
766
+ TagBytes(2, wire_format.WIRETYPE_VARINT),
767
+ _VarintBytes(field_number),
768
+ TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)])
769
+ end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
770
+ local_EncodeVarint = _EncodeVarint
771
+
772
+ def EncodeField(write, value, deterministic):
773
+ write(start_bytes)
774
+ local_EncodeVarint(write, value.ByteSize(), deterministic)
775
+ value._InternalSerialize(write, deterministic)
776
+ return write(end_bytes)
777
+
778
+ return EncodeField
779
+
780
+
781
+ # --------------------------------------------------------------------
782
+ # As before, Map is special.
783
+
784
+
785
+ def MapEncoder(field_descriptor):
786
+ """Encoder for extensions of MessageSet.
787
+
788
+ Maps always have a wire format like this:
789
+ message MapEntry {
790
+ key_type key = 1;
791
+ value_type value = 2;
792
+ }
793
+ repeated MapEntry map = N;
794
+ """
795
+ # Can't look at field_descriptor.message_type._concrete_class because it may
796
+ # not have been initialized yet.
797
+ message_type = field_descriptor.message_type
798
+ encode_message = MessageEncoder(field_descriptor.number, False, False)
799
+
800
+ def EncodeField(write, value, deterministic):
801
+ value_keys = sorted(value.keys()) if deterministic else value
802
+ for key in value_keys:
803
+ entry_msg = message_type._concrete_class(key=key, value=value[key])
804
+ encode_message(write, entry_msg, deterministic)
805
+
806
+ return EncodeField
.venv/lib/python3.11/site-packages/google/protobuf/internal/python_edition_defaults.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """
2
+ This file contains the serialized FeatureSetDefaults object corresponding to
3
+ the Pure Python runtime. This is used for feature resolution under Editions.
4
+ """
5
+ _PROTOBUF_INTERNAL_PYTHON_EDITION_DEFAULTS = b"\n\023\030\204\007\"\000*\014\010\001\020\002\030\002 \003(\0010\002\n\023\030\347\007\"\000*\014\010\002\020\001\030\001 \002(\0010\001\n\023\030\350\007\"\014\010\001\020\001\030\001 \002(\0010\001*\000 \346\007(\350\007"
.venv/lib/python3.11/site-packages/google/protobuf/internal/testing_refleaks.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protocol Buffers - Google's data interchange format
2
+ # Copyright 2008 Google Inc. All rights reserved.
3
+ #
4
+ # Use of this source code is governed by a BSD-style
5
+ # license that can be found in the LICENSE file or at
6
+ # https://developers.google.com/open-source/licenses/bsd
7
+
8
+ """A subclass of unittest.TestCase which checks for reference leaks.
9
+
10
+ To use:
11
+ - Use testing_refleak.BaseTestCase instead of unittest.TestCase
12
+ - Configure and compile Python with --with-pydebug
13
+
14
+ If sys.gettotalrefcount() is not available (because Python was built without
15
+ the Py_DEBUG option), then this module is a no-op and tests will run normally.
16
+ """
17
+
18
+ import copyreg
19
+ import gc
20
+ import sys
21
+ import unittest
22
+
23
+
24
+ class LocalTestResult(unittest.TestResult):
25
+ """A TestResult which forwards events to a parent object, except for Skips."""
26
+
27
+ def __init__(self, parent_result):
28
+ unittest.TestResult.__init__(self)
29
+ self.parent_result = parent_result
30
+
31
+ def addError(self, test, error):
32
+ self.parent_result.addError(test, error)
33
+
34
+ def addFailure(self, test, error):
35
+ self.parent_result.addFailure(test, error)
36
+
37
+ def addSkip(self, test, reason):
38
+ pass
39
+
40
+
41
+ class ReferenceLeakCheckerMixin(object):
42
+ """A mixin class for TestCase, which checks reference counts."""
43
+
44
+ NB_RUNS = 3
45
+
46
+ def run(self, result=None):
47
+ testMethod = getattr(self, self._testMethodName)
48
+ expecting_failure_method = getattr(testMethod, "__unittest_expecting_failure__", False)
49
+ expecting_failure_class = getattr(self, "__unittest_expecting_failure__", False)
50
+ if expecting_failure_class or expecting_failure_method:
51
+ return
52
+
53
+ # python_message.py registers all Message classes to some pickle global
54
+ # registry, which makes the classes immortal.
55
+ # We save a copy of this registry, and reset it before we could references.
56
+ self._saved_pickle_registry = copyreg.dispatch_table.copy()
57
+
58
+ # Run the test twice, to warm up the instance attributes.
59
+ super(ReferenceLeakCheckerMixin, self).run(result=result)
60
+ super(ReferenceLeakCheckerMixin, self).run(result=result)
61
+
62
+ oldrefcount = 0
63
+ local_result = LocalTestResult(result)
64
+ num_flakes = 0
65
+
66
+ refcount_deltas = []
67
+ while len(refcount_deltas) < self.NB_RUNS:
68
+ oldrefcount = self._getRefcounts()
69
+ super(ReferenceLeakCheckerMixin, self).run(result=local_result)
70
+ newrefcount = self._getRefcounts()
71
+ # If the GC was able to collect some objects after the call to run() that
72
+ # it could not collect before the call, then the counts won't match.
73
+ if newrefcount < oldrefcount and num_flakes < 2:
74
+ # This result is (probably) a flake -- garbage collectors aren't very
75
+ # predictable, but a lower ending refcount is the opposite of the
76
+ # failure we are testing for. If the result is repeatable, then we will
77
+ # eventually report it, but not after trying to eliminate it.
78
+ num_flakes += 1
79
+ continue
80
+ num_flakes = 0
81
+ refcount_deltas.append(newrefcount - oldrefcount)
82
+ print(refcount_deltas, self)
83
+
84
+ try:
85
+ self.assertEqual(refcount_deltas, [0] * self.NB_RUNS)
86
+ except Exception: # pylint: disable=broad-except
87
+ result.addError(self, sys.exc_info())
88
+
89
+ def _getRefcounts(self):
90
+ copyreg.dispatch_table.clear()
91
+ copyreg.dispatch_table.update(self._saved_pickle_registry)
92
+ # It is sometimes necessary to gc.collect() multiple times, to ensure
93
+ # that all objects can be collected.
94
+ gc.collect()
95
+ gc.collect()
96
+ gc.collect()
97
+ return sys.gettotalrefcount()
98
+
99
+
100
+ if hasattr(sys, 'gettotalrefcount'):
101
+
102
+ def TestCase(test_class):
103
+ new_bases = (ReferenceLeakCheckerMixin,) + test_class.__bases__
104
+ new_class = type(test_class)(
105
+ test_class.__name__, new_bases, dict(test_class.__dict__))
106
+ return new_class
107
+ SkipReferenceLeakChecker = unittest.skip
108
+
109
+ else:
110
+ # When PyDEBUG is not enabled, run the tests normally.
111
+
112
+ def TestCase(test_class):
113
+ return test_class
114
+
115
+ def SkipReferenceLeakChecker(reason):
116
+ del reason # Don't skip, so don't need a reason.
117
+ def Same(func):
118
+ return func
119
+ return Same
.venv/lib/python3.11/site-packages/google/protobuf/internal/well_known_types.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protocol Buffers - Google's data interchange format
2
+ # Copyright 2008 Google Inc. All rights reserved.
3
+ #
4
+ # Use of this source code is governed by a BSD-style
5
+ # license that can be found in the LICENSE file or at
6
+ # https://developers.google.com/open-source/licenses/bsd
7
+
8
+ """Contains well known classes.
9
+
10
+ This files defines well known classes which need extra maintenance including:
11
+ - Any
12
+ - Duration
13
+ - FieldMask
14
+ - Struct
15
+ - Timestamp
16
+ """
17
+
18
+ __author__ = 'jieluo@google.com (Jie Luo)'
19
+
20
+ import calendar
21
+ import collections.abc
22
+ import datetime
23
+ import warnings
24
+ from google.protobuf.internal import field_mask
25
+ from typing import Union
26
+
27
+ FieldMask = field_mask.FieldMask
28
+
29
+ _TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
30
+ _NANOS_PER_SECOND = 1000000000
31
+ _NANOS_PER_MILLISECOND = 1000000
32
+ _NANOS_PER_MICROSECOND = 1000
33
+ _MILLIS_PER_SECOND = 1000
34
+ _MICROS_PER_SECOND = 1000000
35
+ _SECONDS_PER_DAY = 24 * 3600
36
+ _DURATION_SECONDS_MAX = 315576000000
37
+ _TIMESTAMP_SECONDS_MIN = -62135596800
38
+ _TIMESTAMP_SECONDS_MAX = 253402300799
39
+
40
+ _EPOCH_DATETIME_NAIVE = datetime.datetime(1970, 1, 1, tzinfo=None)
41
+ _EPOCH_DATETIME_AWARE = _EPOCH_DATETIME_NAIVE.replace(
42
+ tzinfo=datetime.timezone.utc
43
+ )
44
+
45
+
46
+ class Any(object):
47
+ """Class for Any Message type."""
48
+
49
+ __slots__ = ()
50
+
51
+ def Pack(self, msg, type_url_prefix='type.googleapis.com/',
52
+ deterministic=None):
53
+ """Packs the specified message into current Any message."""
54
+ if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/':
55
+ self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
56
+ else:
57
+ self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
58
+ self.value = msg.SerializeToString(deterministic=deterministic)
59
+
60
+ def Unpack(self, msg):
61
+ """Unpacks the current Any message into specified message."""
62
+ descriptor = msg.DESCRIPTOR
63
+ if not self.Is(descriptor):
64
+ return False
65
+ msg.ParseFromString(self.value)
66
+ return True
67
+
68
+ def TypeName(self):
69
+ """Returns the protobuf type name of the inner message."""
70
+ # Only last part is to be used: b/25630112
71
+ return self.type_url.split('/')[-1]
72
+
73
+ def Is(self, descriptor):
74
+ """Checks if this Any represents the given protobuf type."""
75
+ return '/' in self.type_url and self.TypeName() == descriptor.full_name
76
+
77
+
78
+ class Timestamp(object):
79
+ """Class for Timestamp message type."""
80
+
81
+ __slots__ = ()
82
+
83
+ def ToJsonString(self):
84
+ """Converts Timestamp to RFC 3339 date string format.
85
+
86
+ Returns:
87
+ A string converted from timestamp. The string is always Z-normalized
88
+ and uses 3, 6 or 9 fractional digits as required to represent the
89
+ exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
90
+ """
91
+ _CheckTimestampValid(self.seconds, self.nanos)
92
+ nanos = self.nanos
93
+ seconds = self.seconds % _SECONDS_PER_DAY
94
+ days = (self.seconds - seconds) // _SECONDS_PER_DAY
95
+ dt = datetime.datetime(1970, 1, 1) + datetime.timedelta(days, seconds)
96
+
97
+ result = dt.isoformat()
98
+ if (nanos % 1e9) == 0:
99
+ # If there are 0 fractional digits, the fractional
100
+ # point '.' should be omitted when serializing.
101
+ return result + 'Z'
102
+ if (nanos % 1e6) == 0:
103
+ # Serialize 3 fractional digits.
104
+ return result + '.%03dZ' % (nanos / 1e6)
105
+ if (nanos % 1e3) == 0:
106
+ # Serialize 6 fractional digits.
107
+ return result + '.%06dZ' % (nanos / 1e3)
108
+ # Serialize 9 fractional digits.
109
+ return result + '.%09dZ' % nanos
110
+
111
+ def FromJsonString(self, value):
112
+ """Parse a RFC 3339 date string format to Timestamp.
113
+
114
+ Args:
115
+ value: A date string. Any fractional digits (or none) and any offset are
116
+ accepted as long as they fit into nano-seconds precision.
117
+ Example of accepted format: '1972-01-01T10:00:20.021-05:00'
118
+
119
+ Raises:
120
+ ValueError: On parsing problems.
121
+ """
122
+ if not isinstance(value, str):
123
+ raise ValueError('Timestamp JSON value not a string: {!r}'.format(value))
124
+ timezone_offset = value.find('Z')
125
+ if timezone_offset == -1:
126
+ timezone_offset = value.find('+')
127
+ if timezone_offset == -1:
128
+ timezone_offset = value.rfind('-')
129
+ if timezone_offset == -1:
130
+ raise ValueError(
131
+ 'Failed to parse timestamp: missing valid timezone offset.')
132
+ time_value = value[0:timezone_offset]
133
+ # Parse datetime and nanos.
134
+ point_position = time_value.find('.')
135
+ if point_position == -1:
136
+ second_value = time_value
137
+ nano_value = ''
138
+ else:
139
+ second_value = time_value[:point_position]
140
+ nano_value = time_value[point_position + 1:]
141
+ if 't' in second_value:
142
+ raise ValueError(
143
+ 'time data \'{0}\' does not match format \'%Y-%m-%dT%H:%M:%S\', '
144
+ 'lowercase \'t\' is not accepted'.format(second_value))
145
+ date_object = datetime.datetime.strptime(second_value, _TIMESTAMPFOMAT)
146
+ td = date_object - datetime.datetime(1970, 1, 1)
147
+ seconds = td.seconds + td.days * _SECONDS_PER_DAY
148
+ if len(nano_value) > 9:
149
+ raise ValueError(
150
+ 'Failed to parse Timestamp: nanos {0} more than '
151
+ '9 fractional digits.'.format(nano_value))
152
+ if nano_value:
153
+ nanos = round(float('0.' + nano_value) * 1e9)
154
+ else:
155
+ nanos = 0
156
+ # Parse timezone offsets.
157
+ if value[timezone_offset] == 'Z':
158
+ if len(value) != timezone_offset + 1:
159
+ raise ValueError('Failed to parse timestamp: invalid trailing'
160
+ ' data {0}.'.format(value))
161
+ else:
162
+ timezone = value[timezone_offset:]
163
+ pos = timezone.find(':')
164
+ if pos == -1:
165
+ raise ValueError(
166
+ 'Invalid timezone offset value: {0}.'.format(timezone))
167
+ if timezone[0] == '+':
168
+ seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
169
+ else:
170
+ seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
171
+ # Set seconds and nanos
172
+ _CheckTimestampValid(seconds, nanos)
173
+ self.seconds = int(seconds)
174
+ self.nanos = int(nanos)
175
+
176
+ def GetCurrentTime(self):
177
+ """Get the current UTC into Timestamp."""
178
+ self.FromDatetime(datetime.datetime.utcnow())
179
+
180
+ def ToNanoseconds(self):
181
+ """Converts Timestamp to nanoseconds since epoch."""
182
+ _CheckTimestampValid(self.seconds, self.nanos)
183
+ return self.seconds * _NANOS_PER_SECOND + self.nanos
184
+
185
+ def ToMicroseconds(self):
186
+ """Converts Timestamp to microseconds since epoch."""
187
+ _CheckTimestampValid(self.seconds, self.nanos)
188
+ return (self.seconds * _MICROS_PER_SECOND +
189
+ self.nanos // _NANOS_PER_MICROSECOND)
190
+
191
+ def ToMilliseconds(self):
192
+ """Converts Timestamp to milliseconds since epoch."""
193
+ _CheckTimestampValid(self.seconds, self.nanos)
194
+ return (self.seconds * _MILLIS_PER_SECOND +
195
+ self.nanos // _NANOS_PER_MILLISECOND)
196
+
197
+ def ToSeconds(self):
198
+ """Converts Timestamp to seconds since epoch."""
199
+ _CheckTimestampValid(self.seconds, self.nanos)
200
+ return self.seconds
201
+
202
+ def FromNanoseconds(self, nanos):
203
+ """Converts nanoseconds since epoch to Timestamp."""
204
+ seconds = nanos // _NANOS_PER_SECOND
205
+ nanos = nanos % _NANOS_PER_SECOND
206
+ _CheckTimestampValid(seconds, nanos)
207
+ self.seconds = seconds
208
+ self.nanos = nanos
209
+
210
+ def FromMicroseconds(self, micros):
211
+ """Converts microseconds since epoch to Timestamp."""
212
+ seconds = micros // _MICROS_PER_SECOND
213
+ nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND
214
+ _CheckTimestampValid(seconds, nanos)
215
+ self.seconds = seconds
216
+ self.nanos = nanos
217
+
218
+ def FromMilliseconds(self, millis):
219
+ """Converts milliseconds since epoch to Timestamp."""
220
+ seconds = millis // _MILLIS_PER_SECOND
221
+ nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND
222
+ _CheckTimestampValid(seconds, nanos)
223
+ self.seconds = seconds
224
+ self.nanos = nanos
225
+
226
+ def FromSeconds(self, seconds):
227
+ """Converts seconds since epoch to Timestamp."""
228
+ _CheckTimestampValid(seconds, 0)
229
+ self.seconds = seconds
230
+ self.nanos = 0
231
+
232
+ def ToDatetime(self, tzinfo=None):
233
+ """Converts Timestamp to a datetime.
234
+
235
+ Args:
236
+ tzinfo: A datetime.tzinfo subclass; defaults to None.
237
+
238
+ Returns:
239
+ If tzinfo is None, returns a timezone-naive UTC datetime (with no timezone
240
+ information, i.e. not aware that it's UTC).
241
+
242
+ Otherwise, returns a timezone-aware datetime in the input timezone.
243
+ """
244
+ # Using datetime.fromtimestamp for this would avoid constructing an extra
245
+ # timedelta object and possibly an extra datetime. Unfortunately, that has
246
+ # the disadvantage of not handling the full precision (on all platforms, see
247
+ # https://github.com/python/cpython/issues/109849) or full range (on some
248
+ # platforms, see https://github.com/python/cpython/issues/110042) of
249
+ # datetime.
250
+ _CheckTimestampValid(self.seconds, self.nanos)
251
+ delta = datetime.timedelta(
252
+ seconds=self.seconds,
253
+ microseconds=_RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND),
254
+ )
255
+ if tzinfo is None:
256
+ return _EPOCH_DATETIME_NAIVE + delta
257
+ else:
258
+ # Note the tz conversion has to come after the timedelta arithmetic.
259
+ return (_EPOCH_DATETIME_AWARE + delta).astimezone(tzinfo)
260
+
261
+ def FromDatetime(self, dt):
262
+ """Converts datetime to Timestamp.
263
+
264
+ Args:
265
+ dt: A datetime. If it's timezone-naive, it's assumed to be in UTC.
266
+ """
267
+ # Using this guide: http://wiki.python.org/moin/WorkingWithTime
268
+ # And this conversion guide: http://docs.python.org/library/time.html
269
+
270
+ # Turn the date parameter into a tuple (struct_time) that can then be
271
+ # manipulated into a long value of seconds. During the conversion from
272
+ # struct_time to long, the source date in UTC, and so it follows that the
273
+ # correct transformation is calendar.timegm()
274
+ try:
275
+ seconds = calendar.timegm(dt.utctimetuple())
276
+ nanos = dt.microsecond * _NANOS_PER_MICROSECOND
277
+ except AttributeError as e:
278
+ raise AttributeError(
279
+ 'Fail to convert to Timestamp. Expected a datetime like '
280
+ 'object got {0} : {1}'.format(type(dt).__name__, e)
281
+ ) from e
282
+ _CheckTimestampValid(seconds, nanos)
283
+ self.seconds = seconds
284
+ self.nanos = nanos
285
+
286
+ def _internal_assign(self, dt):
287
+ self.FromDatetime(dt)
288
+
289
+ def __add__(self, value) -> datetime.datetime:
290
+ if isinstance(value, Duration):
291
+ return self.ToDatetime() + value.ToTimedelta()
292
+ return self.ToDatetime() + value
293
+
294
+ __radd__ = __add__
295
+
296
+ def __sub__(self, value) -> Union[datetime.datetime, datetime.timedelta]:
297
+ if isinstance(value, Timestamp):
298
+ return self.ToDatetime() - value.ToDatetime()
299
+ elif isinstance(value, Duration):
300
+ return self.ToDatetime() - value.ToTimedelta()
301
+ return self.ToDatetime() - value
302
+
303
+ def __rsub__(self, dt) -> datetime.timedelta:
304
+ return dt - self.ToDatetime()
305
+
306
+
307
+ def _CheckTimestampValid(seconds, nanos):
308
+ if seconds < _TIMESTAMP_SECONDS_MIN or seconds > _TIMESTAMP_SECONDS_MAX:
309
+ raise ValueError(
310
+ 'Timestamp is not valid: Seconds {0} must be in range '
311
+ '[-62135596800, 253402300799].'.format(seconds))
312
+ if nanos < 0 or nanos >= _NANOS_PER_SECOND:
313
+ raise ValueError(
314
+ 'Timestamp is not valid: Nanos {} must be in a range '
315
+ '[0, 999999].'.format(nanos))
316
+
317
+
318
+ class Duration(object):
319
+ """Class for Duration message type."""
320
+
321
+ __slots__ = ()
322
+
323
+ def ToJsonString(self):
324
+ """Converts Duration to string format.
325
+
326
+ Returns:
327
+ A string converted from self. The string format will contains
328
+ 3, 6, or 9 fractional digits depending on the precision required to
329
+ represent the exact Duration value. For example: "1s", "1.010s",
330
+ "1.000000100s", "-3.100s"
331
+ """
332
+ _CheckDurationValid(self.seconds, self.nanos)
333
+ if self.seconds < 0 or self.nanos < 0:
334
+ result = '-'
335
+ seconds = - self.seconds + int((0 - self.nanos) // 1e9)
336
+ nanos = (0 - self.nanos) % 1e9
337
+ else:
338
+ result = ''
339
+ seconds = self.seconds + int(self.nanos // 1e9)
340
+ nanos = self.nanos % 1e9
341
+ result += '%d' % seconds
342
+ if (nanos % 1e9) == 0:
343
+ # If there are 0 fractional digits, the fractional
344
+ # point '.' should be omitted when serializing.
345
+ return result + 's'
346
+ if (nanos % 1e6) == 0:
347
+ # Serialize 3 fractional digits.
348
+ return result + '.%03ds' % (nanos / 1e6)
349
+ if (nanos % 1e3) == 0:
350
+ # Serialize 6 fractional digits.
351
+ return result + '.%06ds' % (nanos / 1e3)
352
+ # Serialize 9 fractional digits.
353
+ return result + '.%09ds' % nanos
354
+
355
+ def FromJsonString(self, value):
356
+ """Converts a string to Duration.
357
+
358
+ Args:
359
+ value: A string to be converted. The string must end with 's'. Any
360
+ fractional digits (or none) are accepted as long as they fit into
361
+ precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s
362
+
363
+ Raises:
364
+ ValueError: On parsing problems.
365
+ """
366
+ if not isinstance(value, str):
367
+ raise ValueError('Duration JSON value not a string: {!r}'.format(value))
368
+ if len(value) < 1 or value[-1] != 's':
369
+ raise ValueError(
370
+ 'Duration must end with letter "s": {0}.'.format(value))
371
+ try:
372
+ pos = value.find('.')
373
+ if pos == -1:
374
+ seconds = int(value[:-1])
375
+ nanos = 0
376
+ else:
377
+ seconds = int(value[:pos])
378
+ if value[0] == '-':
379
+ nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
380
+ else:
381
+ nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
382
+ _CheckDurationValid(seconds, nanos)
383
+ self.seconds = seconds
384
+ self.nanos = nanos
385
+ except ValueError as e:
386
+ raise ValueError(
387
+ 'Couldn\'t parse duration: {0} : {1}.'.format(value, e))
388
+
389
+ def ToNanoseconds(self):
390
+ """Converts a Duration to nanoseconds."""
391
+ return self.seconds * _NANOS_PER_SECOND + self.nanos
392
+
393
+ def ToMicroseconds(self):
394
+ """Converts a Duration to microseconds."""
395
+ micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
396
+ return self.seconds * _MICROS_PER_SECOND + micros
397
+
398
+ def ToMilliseconds(self):
399
+ """Converts a Duration to milliseconds."""
400
+ millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND)
401
+ return self.seconds * _MILLIS_PER_SECOND + millis
402
+
403
+ def ToSeconds(self):
404
+ """Converts a Duration to seconds."""
405
+ return self.seconds
406
+
407
+ def FromNanoseconds(self, nanos):
408
+ """Converts nanoseconds to Duration."""
409
+ self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
410
+ nanos % _NANOS_PER_SECOND)
411
+
412
+ def FromMicroseconds(self, micros):
413
+ """Converts microseconds to Duration."""
414
+ self._NormalizeDuration(
415
+ micros // _MICROS_PER_SECOND,
416
+ (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)
417
+
418
+ def FromMilliseconds(self, millis):
419
+ """Converts milliseconds to Duration."""
420
+ self._NormalizeDuration(
421
+ millis // _MILLIS_PER_SECOND,
422
+ (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)
423
+
424
+ def FromSeconds(self, seconds):
425
+ """Converts seconds to Duration."""
426
+ self.seconds = seconds
427
+ self.nanos = 0
428
+
429
+ def ToTimedelta(self) -> datetime.timedelta:
430
+ """Converts Duration to timedelta."""
431
+ return datetime.timedelta(
432
+ seconds=self.seconds, microseconds=_RoundTowardZero(
433
+ self.nanos, _NANOS_PER_MICROSECOND))
434
+
435
+ def FromTimedelta(self, td):
436
+ """Converts timedelta to Duration."""
437
+ try:
438
+ self._NormalizeDuration(
439
+ td.seconds + td.days * _SECONDS_PER_DAY,
440
+ td.microseconds * _NANOS_PER_MICROSECOND,
441
+ )
442
+ except AttributeError as e:
443
+ raise AttributeError(
444
+ 'Fail to convert to Duration. Expected a timedelta like '
445
+ 'object got {0}: {1}'.format(type(td).__name__, e)
446
+ ) from e
447
+
448
+ def _internal_assign(self, td):
449
+ self.FromTimedelta(td)
450
+
451
+ def _NormalizeDuration(self, seconds, nanos):
452
+ """Set Duration by seconds and nanos."""
453
+ # Force nanos to be negative if the duration is negative.
454
+ if seconds < 0 and nanos > 0:
455
+ seconds += 1
456
+ nanos -= _NANOS_PER_SECOND
457
+ self.seconds = seconds
458
+ self.nanos = nanos
459
+
460
+ def __add__(self, value) -> Union[datetime.datetime, datetime.timedelta]:
461
+ if isinstance(value, Timestamp):
462
+ return self.ToTimedelta() + value.ToDatetime()
463
+ return self.ToTimedelta() + value
464
+
465
+ __radd__ = __add__
466
+
467
+ def __rsub__(self, dt) -> Union[datetime.datetime, datetime.timedelta]:
468
+ return dt - self.ToTimedelta()
469
+
470
+
471
+ def _CheckDurationValid(seconds, nanos):
472
+ if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX:
473
+ raise ValueError(
474
+ 'Duration is not valid: Seconds {0} must be in range '
475
+ '[-315576000000, 315576000000].'.format(seconds))
476
+ if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND:
477
+ raise ValueError(
478
+ 'Duration is not valid: Nanos {0} must be in range '
479
+ '[-999999999, 999999999].'.format(nanos))
480
+ if (nanos < 0 and seconds > 0) or (nanos > 0 and seconds < 0):
481
+ raise ValueError(
482
+ 'Duration is not valid: Sign mismatch.')
483
+
484
+
485
+ def _RoundTowardZero(value, divider):
486
+ """Truncates the remainder part after division."""
487
+ # For some languages, the sign of the remainder is implementation
488
+ # dependent if any of the operands is negative. Here we enforce
489
+ # "rounded toward zero" semantics. For example, for (-5) / 2 an
490
+ # implementation may give -3 as the result with the remainder being
491
+ # 1. This function ensures we always return -2 (closer to zero).
492
+ result = value // divider
493
+ remainder = value % divider
494
+ if result < 0 and remainder > 0:
495
+ return result + 1
496
+ else:
497
+ return result
498
+
499
+
500
+ def _SetStructValue(struct_value, value):
501
+ if value is None:
502
+ struct_value.null_value = 0
503
+ elif isinstance(value, bool):
504
+ # Note: this check must come before the number check because in Python
505
+ # True and False are also considered numbers.
506
+ struct_value.bool_value = value
507
+ elif isinstance(value, str):
508
+ struct_value.string_value = value
509
+ elif isinstance(value, (int, float)):
510
+ struct_value.number_value = value
511
+ elif isinstance(value, (dict, Struct)):
512
+ struct_value.struct_value.Clear()
513
+ struct_value.struct_value.update(value)
514
+ elif isinstance(value, (list, tuple, ListValue)):
515
+ struct_value.list_value.Clear()
516
+ struct_value.list_value.extend(value)
517
+ else:
518
+ raise ValueError('Unexpected type')
519
+
520
+
521
+ def _GetStructValue(struct_value):
522
+ which = struct_value.WhichOneof('kind')
523
+ if which == 'struct_value':
524
+ return struct_value.struct_value
525
+ elif which == 'null_value':
526
+ return None
527
+ elif which == 'number_value':
528
+ return struct_value.number_value
529
+ elif which == 'string_value':
530
+ return struct_value.string_value
531
+ elif which == 'bool_value':
532
+ return struct_value.bool_value
533
+ elif which == 'list_value':
534
+ return struct_value.list_value
535
+ elif which is None:
536
+ raise ValueError('Value not set')
537
+
538
+
539
+ class Struct(object):
540
+ """Class for Struct message type."""
541
+
542
+ __slots__ = ()
543
+
544
+ def __getitem__(self, key):
545
+ return _GetStructValue(self.fields[key])
546
+
547
+ def __setitem__(self, key, value):
548
+ _SetStructValue(self.fields[key], value)
549
+
550
+ def __delitem__(self, key):
551
+ del self.fields[key]
552
+
553
+ def __len__(self):
554
+ return len(self.fields)
555
+
556
+ def __iter__(self):
557
+ return iter(self.fields)
558
+
559
+ def _internal_assign(self, dictionary):
560
+ self.Clear()
561
+ self.update(dictionary)
562
+
563
+ def _internal_compare(self, other):
564
+ size = len(self)
565
+ if size != len(other):
566
+ return False
567
+ for key, value in self.items():
568
+ if key not in other:
569
+ return False
570
+ if isinstance(other[key], (dict, list)):
571
+ if not value._internal_compare(other[key]):
572
+ return False
573
+ elif value != other[key]:
574
+ return False
575
+ return True
576
+
577
+ def keys(self): # pylint: disable=invalid-name
578
+ return self.fields.keys()
579
+
580
+ def values(self): # pylint: disable=invalid-name
581
+ return [self[key] for key in self]
582
+
583
+ def items(self): # pylint: disable=invalid-name
584
+ return [(key, self[key]) for key in self]
585
+
586
+ def get_or_create_list(self, key):
587
+ """Returns a list for this key, creating if it didn't exist already."""
588
+ if not self.fields[key].HasField('list_value'):
589
+ # Clear will mark list_value modified which will indeed create a list.
590
+ self.fields[key].list_value.Clear()
591
+ return self.fields[key].list_value
592
+
593
+ def get_or_create_struct(self, key):
594
+ """Returns a struct for this key, creating if it didn't exist already."""
595
+ if not self.fields[key].HasField('struct_value'):
596
+ # Clear will mark struct_value modified which will indeed create a struct.
597
+ self.fields[key].struct_value.Clear()
598
+ return self.fields[key].struct_value
599
+
600
+ def update(self, dictionary): # pylint: disable=invalid-name
601
+ for key, value in dictionary.items():
602
+ _SetStructValue(self.fields[key], value)
603
+
604
+ collections.abc.MutableMapping.register(Struct)
605
+
606
+
607
+ class ListValue(object):
608
+ """Class for ListValue message type."""
609
+
610
+ __slots__ = ()
611
+
612
+ def __len__(self):
613
+ return len(self.values)
614
+
615
+ def append(self, value):
616
+ _SetStructValue(self.values.add(), value)
617
+
618
+ def extend(self, elem_seq):
619
+ for value in elem_seq:
620
+ self.append(value)
621
+
622
+ def __getitem__(self, index):
623
+ """Retrieves item by the specified index."""
624
+ return _GetStructValue(self.values.__getitem__(index))
625
+
626
+ def __setitem__(self, index, value):
627
+ _SetStructValue(self.values.__getitem__(index), value)
628
+
629
+ def __delitem__(self, key):
630
+ del self.values[key]
631
+
632
+ def _internal_assign(self, elem_seq):
633
+ self.Clear()
634
+ self.extend(elem_seq)
635
+
636
+ def _internal_compare(self, other):
637
+ size = len(self)
638
+ if size != len(other):
639
+ return False
640
+ for i in range(size):
641
+ if isinstance(other[i], (dict, list)):
642
+ if not self[i]._internal_compare(other[i]):
643
+ return False
644
+ elif self[i] != other[i]:
645
+ return False
646
+ return True
647
+
648
+ def items(self):
649
+ for i in range(len(self)):
650
+ yield self[i]
651
+
652
+ def add_struct(self):
653
+ """Appends and returns a struct value as the next value in the list."""
654
+ struct_value = self.values.add().struct_value
655
+ # Clear will mark struct_value modified which will indeed create a struct.
656
+ struct_value.Clear()
657
+ return struct_value
658
+
659
+ def add_list(self):
660
+ """Appends and returns a list value as the next value in the list."""
661
+ list_value = self.values.add().list_value
662
+ # Clear will mark list_value modified which will indeed create a list.
663
+ list_value.Clear()
664
+ return list_value
665
+
666
+ collections.abc.MutableSequence.register(ListValue)
667
+
668
+
669
+ # LINT.IfChange(wktbases)
670
+ WKTBASES = {
671
+ 'google.protobuf.Any': Any,
672
+ 'google.protobuf.Duration': Duration,
673
+ 'google.protobuf.FieldMask': FieldMask,
674
+ 'google.protobuf.ListValue': ListValue,
675
+ 'google.protobuf.Struct': Struct,
676
+ 'google.protobuf.Timestamp': Timestamp,
677
+ }
678
+ # LINT.ThenChange(//depot/google.protobuf/compiler/python/pyi_generator.cc:wktbases)