Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/google/generativeai/__pycache__/files.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/__pycache__/generative_models.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/__pycache__/models.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/__pycache__/responder.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/__pycache__/string_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/__pycache__/text.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/answer_types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/caching_types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/content_types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/discuss_types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/helper_types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/model_types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/retriever_types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/safety_types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/answer_types.py +58 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/caching_types.py +83 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/content_types.py +985 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/discuss_types.py +208 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/file_types.py +143 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/generation_types.py +759 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/image_types/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/image_types/__pycache__/_image_types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/image_types/_image_types.py +440 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/model_types.py +390 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/palm_safety_types.py +286 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/safety_types.py +303 -0
- .venv/lib/python3.11/site-packages/google/generativeai/types/text_types.py +32 -0
- .venv/lib/python3.11/site-packages/google/logging/type/__pycache__/http_request_pb2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/logging/type/__pycache__/log_severity_pb2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/google/logging/type/http_request.proto +95 -0
- .venv/lib/python3.11/site-packages/google/logging/type/log_severity.proto +71 -0
- .venv/lib/python3.11/site-packages/google/logging/type/log_severity_pb2.py +44 -0
- .venv/lib/python3.11/site-packages/google/protobuf/__init__.py +10 -0
- .venv/lib/python3.11/site-packages/google/protobuf/any.py +39 -0
- .venv/lib/python3.11/site-packages/google/protobuf/any_pb2.py +37 -0
- .venv/lib/python3.11/site-packages/google/protobuf/api_pb2.py +43 -0
- .venv/lib/python3.11/site-packages/google/protobuf/descriptor.py +1511 -0
- .venv/lib/python3.11/site-packages/google/protobuf/descriptor_database.py +154 -0
- .venv/lib/python3.11/site-packages/google/protobuf/descriptor_pb2.py +0 -0
- .venv/lib/python3.11/site-packages/google/protobuf/descriptor_pool.py +1355 -0
- .venv/lib/python3.11/site-packages/google/protobuf/duration.py +100 -0
- .venv/lib/python3.11/site-packages/google/protobuf/duration_pb2.py +37 -0
- .venv/lib/python3.11/site-packages/google/protobuf/empty_pb2.py +37 -0
- .venv/lib/python3.11/site-packages/google/protobuf/field_mask_pb2.py +37 -0
- .venv/lib/python3.11/site-packages/google/protobuf/internal/_parameterized.py +420 -0
- .venv/lib/python3.11/site-packages/google/protobuf/internal/containers.py +677 -0
- .venv/lib/python3.11/site-packages/google/protobuf/internal/encoder.py +806 -0
- .venv/lib/python3.11/site-packages/google/protobuf/internal/python_edition_defaults.py +5 -0
- .venv/lib/python3.11/site-packages/google/protobuf/internal/testing_refleaks.py +119 -0
- .venv/lib/python3.11/site-packages/google/protobuf/internal/well_known_types.py +678 -0
.venv/lib/python3.11/site-packages/google/generativeai/__pycache__/files.cpython-311.pyc
ADDED
|
Binary file (4.88 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/__pycache__/generative_models.cpython-311.pyc
ADDED
|
Binary file (35.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/__pycache__/models.cpython-311.pyc
ADDED
|
Binary file (17.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/__pycache__/responder.cpython-311.pyc
ADDED
|
Binary file (24.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/__pycache__/string_utils.cpython-311.pyc
ADDED
|
Binary file (3.78 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/__pycache__/text.cpython-311.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/answer_types.cpython-311.pyc
ADDED
|
Binary file (2.09 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/caching_types.cpython-311.pyc
ADDED
|
Binary file (2.89 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/content_types.cpython-311.pyc
ADDED
|
Binary file (43.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/discuss_types.cpython-311.pyc
ADDED
|
Binary file (8.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/helper_types.cpython-311.pyc
ADDED
|
Binary file (3.42 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/model_types.cpython-311.pyc
ADDED
|
Binary file (18.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/retriever_types.cpython-311.pyc
ADDED
|
Binary file (66.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/types/__pycache__/safety_types.cpython-311.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/types/answer_types.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
from typing import Union
|
| 18 |
+
|
| 19 |
+
from google.generativeai import protos
|
| 20 |
+
|
| 21 |
+
__all__ = ["Answer"]
|
| 22 |
+
|
| 23 |
+
FinishReason = protos.Candidate.FinishReason
|
| 24 |
+
|
| 25 |
+
FinishReasonOptions = Union[int, str, FinishReason]
|
| 26 |
+
|
| 27 |
+
_FINISH_REASONS: dict[FinishReasonOptions, FinishReason] = {
|
| 28 |
+
FinishReason.FINISH_REASON_UNSPECIFIED: FinishReason.FINISH_REASON_UNSPECIFIED,
|
| 29 |
+
0: FinishReason.FINISH_REASON_UNSPECIFIED,
|
| 30 |
+
"finish_reason_unspecified": FinishReason.FINISH_REASON_UNSPECIFIED,
|
| 31 |
+
"unspecified": FinishReason.FINISH_REASON_UNSPECIFIED,
|
| 32 |
+
FinishReason.STOP: FinishReason.STOP,
|
| 33 |
+
1: FinishReason.STOP,
|
| 34 |
+
"finish_reason_stop": FinishReason.STOP,
|
| 35 |
+
"stop": FinishReason.STOP,
|
| 36 |
+
FinishReason.MAX_TOKENS: FinishReason.MAX_TOKENS,
|
| 37 |
+
2: FinishReason.MAX_TOKENS,
|
| 38 |
+
"finish_reason_max_tokens": FinishReason.MAX_TOKENS,
|
| 39 |
+
"max_tokens": FinishReason.MAX_TOKENS,
|
| 40 |
+
FinishReason.SAFETY: FinishReason.SAFETY,
|
| 41 |
+
3: FinishReason.SAFETY,
|
| 42 |
+
"finish_reason_safety": FinishReason.SAFETY,
|
| 43 |
+
"safety": FinishReason.SAFETY,
|
| 44 |
+
FinishReason.RECITATION: FinishReason.RECITATION,
|
| 45 |
+
4: FinishReason.RECITATION,
|
| 46 |
+
"finish_reason_recitation": FinishReason.RECITATION,
|
| 47 |
+
"recitation": FinishReason.RECITATION,
|
| 48 |
+
FinishReason.OTHER: FinishReason.OTHER,
|
| 49 |
+
5: FinishReason.OTHER,
|
| 50 |
+
"finish_reason_other": FinishReason.OTHER,
|
| 51 |
+
"other": FinishReason.OTHER,
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def to_finish_reason(x: FinishReasonOptions) -> FinishReason:
|
| 56 |
+
if isinstance(x, str):
|
| 57 |
+
x = x.lower()
|
| 58 |
+
return _FINISH_REASONS[x]
|
.venv/lib/python3.11/site-packages/google/generativeai/types/caching_types.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2024 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import datetime
|
| 18 |
+
from typing import Union
|
| 19 |
+
from typing_extensions import TypedDict
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
"ExpireTime",
|
| 23 |
+
"TTL",
|
| 24 |
+
"TTLTypes",
|
| 25 |
+
"ExpireTimeTypes",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TTL(TypedDict):
|
| 30 |
+
# Represents datetime.datetime.now() + desired ttl
|
| 31 |
+
seconds: int
|
| 32 |
+
nanos: int
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ExpireTime(TypedDict):
|
| 36 |
+
# Represents seconds of UTC time since Unix epoch
|
| 37 |
+
seconds: int
|
| 38 |
+
nanos: int
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
TTLTypes = Union[TTL, int, datetime.timedelta]
|
| 42 |
+
ExpireTimeTypes = Union[ExpireTime, int, datetime.datetime]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def to_optional_ttl(ttl: TTLTypes | None) -> TTL | None:
|
| 46 |
+
if ttl is None:
|
| 47 |
+
return None
|
| 48 |
+
elif isinstance(ttl, datetime.timedelta):
|
| 49 |
+
return {
|
| 50 |
+
"seconds": int(ttl.total_seconds()),
|
| 51 |
+
"nanos": int(ttl.microseconds * 1000),
|
| 52 |
+
}
|
| 53 |
+
elif isinstance(ttl, dict):
|
| 54 |
+
return ttl
|
| 55 |
+
elif isinstance(ttl, int):
|
| 56 |
+
return {"seconds": ttl, "nanos": 0}
|
| 57 |
+
else:
|
| 58 |
+
raise TypeError(
|
| 59 |
+
f"Could not convert input to `ttl` \n'" f" type: {type(ttl)}\n",
|
| 60 |
+
ttl,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def to_optional_expire_time(expire_time: ExpireTimeTypes | None) -> ExpireTime | None:
|
| 65 |
+
if expire_time is None:
|
| 66 |
+
return expire_time
|
| 67 |
+
elif isinstance(expire_time, datetime.datetime):
|
| 68 |
+
timestamp = expire_time.timestamp()
|
| 69 |
+
seconds = int(timestamp)
|
| 70 |
+
nanos = int((seconds % 1) * 1000)
|
| 71 |
+
return {
|
| 72 |
+
"seconds": seconds,
|
| 73 |
+
"nanos": nanos,
|
| 74 |
+
}
|
| 75 |
+
elif isinstance(expire_time, dict):
|
| 76 |
+
return expire_time
|
| 77 |
+
elif isinstance(expire_time, int):
|
| 78 |
+
return {"seconds": expire_time, "nanos": 0}
|
| 79 |
+
else:
|
| 80 |
+
raise TypeError(
|
| 81 |
+
f"Could not convert input to `expire_time` \n'" f" type: {type(expire_time)}\n",
|
| 82 |
+
expire_time,
|
| 83 |
+
)
|
.venv/lib/python3.11/site-packages/google/generativeai/types/content_types.py
ADDED
|
@@ -0,0 +1,985 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Google LLC
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from collections.abc import Iterable, Mapping, Sequence
|
| 19 |
+
import io
|
| 20 |
+
import inspect
|
| 21 |
+
import mimetypes
|
| 22 |
+
import pathlib
|
| 23 |
+
import typing
|
| 24 |
+
from typing import Any, Callable, Union
|
| 25 |
+
from typing_extensions import TypedDict
|
| 26 |
+
|
| 27 |
+
import pydantic
|
| 28 |
+
|
| 29 |
+
from google.generativeai.types import file_types
|
| 30 |
+
from google.generativeai import protos
|
| 31 |
+
|
| 32 |
+
if typing.TYPE_CHECKING:
|
| 33 |
+
import PIL.Image
|
| 34 |
+
import PIL.ImageFile
|
| 35 |
+
import IPython.display
|
| 36 |
+
|
| 37 |
+
IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
|
| 38 |
+
else:
|
| 39 |
+
IMAGE_TYPES = ()
|
| 40 |
+
try:
|
| 41 |
+
import PIL.Image
|
| 42 |
+
import PIL.ImageFile
|
| 43 |
+
|
| 44 |
+
IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)
|
| 45 |
+
except ImportError:
|
| 46 |
+
PIL = None
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
import IPython.display
|
| 50 |
+
|
| 51 |
+
IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,)
|
| 52 |
+
except ImportError:
|
| 53 |
+
IPython = None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
__all__ = [
|
| 57 |
+
"BlobDict",
|
| 58 |
+
"BlobType",
|
| 59 |
+
"PartDict",
|
| 60 |
+
"PartType",
|
| 61 |
+
"ContentDict",
|
| 62 |
+
"ContentType",
|
| 63 |
+
"StrictContentType",
|
| 64 |
+
"ContentsType",
|
| 65 |
+
"FunctionDeclaration",
|
| 66 |
+
"CallableFunctionDeclaration",
|
| 67 |
+
"FunctionDeclarationType",
|
| 68 |
+
"Tool",
|
| 69 |
+
"ToolDict",
|
| 70 |
+
"ToolsType",
|
| 71 |
+
"FunctionLibrary",
|
| 72 |
+
"FunctionLibraryType",
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
Mode = protos.DynamicRetrievalConfig.Mode
|
| 76 |
+
|
| 77 |
+
ModeOptions = Union[int, str, Mode]
|
| 78 |
+
|
| 79 |
+
_MODE: dict[ModeOptions, Mode] = {
|
| 80 |
+
Mode.MODE_UNSPECIFIED: Mode.MODE_UNSPECIFIED,
|
| 81 |
+
0: Mode.MODE_UNSPECIFIED,
|
| 82 |
+
"mode_unspecified": Mode.MODE_UNSPECIFIED,
|
| 83 |
+
"unspecified": Mode.MODE_UNSPECIFIED,
|
| 84 |
+
Mode.MODE_DYNAMIC: Mode.MODE_DYNAMIC,
|
| 85 |
+
1: Mode.MODE_DYNAMIC,
|
| 86 |
+
"mode_dynamic": Mode.MODE_DYNAMIC,
|
| 87 |
+
"dynamic": Mode.MODE_DYNAMIC,
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def to_mode(x: ModeOptions) -> Mode:
|
| 92 |
+
if isinstance(x, str):
|
| 93 |
+
x = x.lower()
|
| 94 |
+
return _MODE[x]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
|
| 98 |
+
# If the image is a local file, return a file-based blob without any modification.
|
| 99 |
+
# Otherwise, return a lossless WebP blob (same quality with optimized size).
|
| 100 |
+
def file_blob(image: PIL.Image.Image) -> protos.Blob | None:
|
| 101 |
+
if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None:
|
| 102 |
+
return None
|
| 103 |
+
filename = str(image.filename)
|
| 104 |
+
if not pathlib.Path(filename).is_file():
|
| 105 |
+
return None
|
| 106 |
+
|
| 107 |
+
mime_type = image.get_format_mimetype()
|
| 108 |
+
image_bytes = pathlib.Path(filename).read_bytes()
|
| 109 |
+
|
| 110 |
+
return protos.Blob(mime_type=mime_type, data=image_bytes)
|
| 111 |
+
|
| 112 |
+
def webp_blob(image: PIL.Image.Image) -> protos.Blob:
|
| 113 |
+
# Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
|
| 114 |
+
image_io = io.BytesIO()
|
| 115 |
+
image.save(image_io, format="webp", lossless=True)
|
| 116 |
+
image_io.seek(0)
|
| 117 |
+
|
| 118 |
+
mime_type = "image/webp"
|
| 119 |
+
image_bytes = image_io.read()
|
| 120 |
+
|
| 121 |
+
return protos.Blob(mime_type=mime_type, data=image_bytes)
|
| 122 |
+
|
| 123 |
+
return file_blob(image) or webp_blob(image)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def image_to_blob(image) -> protos.Blob:
|
| 127 |
+
if PIL is not None:
|
| 128 |
+
if isinstance(image, PIL.Image.Image):
|
| 129 |
+
return _pil_to_blob(image)
|
| 130 |
+
|
| 131 |
+
if IPython is not None:
|
| 132 |
+
if isinstance(image, IPython.display.Image):
|
| 133 |
+
name = image.filename
|
| 134 |
+
if name is None:
|
| 135 |
+
raise ValueError(
|
| 136 |
+
"Conversion failed. The `IPython.display.Image` can only be converted if "
|
| 137 |
+
"it is constructed from a local file. Please ensure you are using the format: Image(filename='...')."
|
| 138 |
+
)
|
| 139 |
+
mime_type, _ = mimetypes.guess_type(name)
|
| 140 |
+
if mime_type is None:
|
| 141 |
+
mime_type = "image/unknown"
|
| 142 |
+
|
| 143 |
+
return protos.Blob(mime_type=mime_type, data=image.data)
|
| 144 |
+
|
| 145 |
+
raise TypeError(
|
| 146 |
+
"Image conversion failed. The input was expected to be of type `Image` "
|
| 147 |
+
"(either `PIL.Image.Image` or `IPython.display.Image`).\n"
|
| 148 |
+
f"However, received an object of type: {type(image)}.\n"
|
| 149 |
+
f"Object Value: {image}"
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class BlobDict(TypedDict):
|
| 154 |
+
mime_type: str
|
| 155 |
+
data: bytes
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _convert_dict(d: Mapping) -> protos.Content | protos.Part | protos.Blob:
|
| 159 |
+
if is_content_dict(d):
|
| 160 |
+
content = dict(d)
|
| 161 |
+
if isinstance(parts := content["parts"], str):
|
| 162 |
+
content["parts"] = [parts]
|
| 163 |
+
content["parts"] = [to_part(part) for part in content["parts"]]
|
| 164 |
+
return protos.Content(content)
|
| 165 |
+
elif is_part_dict(d):
|
| 166 |
+
part = dict(d)
|
| 167 |
+
if "inline_data" in part:
|
| 168 |
+
part["inline_data"] = to_blob(part["inline_data"])
|
| 169 |
+
if "file_data" in part:
|
| 170 |
+
part["file_data"] = file_types.to_file_data(part["file_data"])
|
| 171 |
+
return protos.Part(part)
|
| 172 |
+
elif is_blob_dict(d):
|
| 173 |
+
blob = d
|
| 174 |
+
return protos.Blob(blob)
|
| 175 |
+
else:
|
| 176 |
+
raise KeyError(
|
| 177 |
+
"Unable to determine the intended type of the `dict`. "
|
| 178 |
+
"For `Content`, a 'parts' key is expected. "
|
| 179 |
+
"For `Part`, either an 'inline_data' or a 'text' key is expected. "
|
| 180 |
+
"For `Blob`, both 'mime_type' and 'data' keys are expected. "
|
| 181 |
+
f"However, the provided dictionary has the following keys: {list(d.keys())}"
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def is_blob_dict(d):
|
| 186 |
+
return "mime_type" in d and "data" in d
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
if typing.TYPE_CHECKING:
|
| 190 |
+
BlobType = Union[
|
| 191 |
+
protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image
|
| 192 |
+
] # Any for the images
|
| 193 |
+
else:
|
| 194 |
+
BlobType = Union[protos.Blob, BlobDict, Any]
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def to_blob(blob: BlobType) -> protos.Blob:
|
| 198 |
+
if isinstance(blob, Mapping):
|
| 199 |
+
blob = _convert_dict(blob)
|
| 200 |
+
|
| 201 |
+
if isinstance(blob, protos.Blob):
|
| 202 |
+
return blob
|
| 203 |
+
elif isinstance(blob, IMAGE_TYPES):
|
| 204 |
+
return image_to_blob(blob)
|
| 205 |
+
else:
|
| 206 |
+
if isinstance(blob, Mapping):
|
| 207 |
+
raise KeyError(
|
| 208 |
+
"Could not recognize the intended type of the `dict`\n" "A content should have "
|
| 209 |
+
)
|
| 210 |
+
raise TypeError(
|
| 211 |
+
"Could not create `Blob`, expected `Blob`, `dict` or an `Image` type"
|
| 212 |
+
"(`PIL.Image.Image` or `IPython.display.Image`).\n"
|
| 213 |
+
f"Got a: {type(blob)}\n"
|
| 214 |
+
f"Value: {blob}"
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class PartDict(TypedDict):
|
| 219 |
+
text: str
|
| 220 |
+
inline_data: BlobType
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# When you need a `Part` accept a part object, part-dict, blob or string
|
| 224 |
+
PartType = Union[
|
| 225 |
+
protos.Part,
|
| 226 |
+
PartDict,
|
| 227 |
+
BlobType,
|
| 228 |
+
str,
|
| 229 |
+
protos.FunctionCall,
|
| 230 |
+
protos.FunctionResponse,
|
| 231 |
+
file_types.FileDataType,
|
| 232 |
+
]
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def is_part_dict(d):
|
| 236 |
+
keys = list(d.keys())
|
| 237 |
+
if len(keys) != 1:
|
| 238 |
+
return False
|
| 239 |
+
|
| 240 |
+
key = keys[0]
|
| 241 |
+
|
| 242 |
+
return key in ["text", "inline_data", "function_call", "function_response", "file_data"]
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def to_part(part: PartType):
|
| 246 |
+
if isinstance(part, Mapping):
|
| 247 |
+
part = _convert_dict(part)
|
| 248 |
+
|
| 249 |
+
if isinstance(part, protos.Part):
|
| 250 |
+
return part
|
| 251 |
+
elif isinstance(part, str):
|
| 252 |
+
return protos.Part(text=part)
|
| 253 |
+
elif isinstance(part, protos.FileData):
|
| 254 |
+
return protos.Part(file_data=part)
|
| 255 |
+
elif isinstance(part, (protos.File, file_types.File)):
|
| 256 |
+
return protos.Part(file_data=file_types.to_file_data(part))
|
| 257 |
+
elif isinstance(part, protos.FunctionCall):
|
| 258 |
+
return protos.Part(function_call=part)
|
| 259 |
+
elif isinstance(part, protos.FunctionResponse):
|
| 260 |
+
return protos.Part(function_response=part)
|
| 261 |
+
|
| 262 |
+
else:
|
| 263 |
+
# Maybe it can be turned into a blob?
|
| 264 |
+
return protos.Part(inline_data=to_blob(part))
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class ContentDict(TypedDict):
|
| 268 |
+
parts: list[PartType]
|
| 269 |
+
role: str
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def is_content_dict(d):
|
| 273 |
+
return "parts" in d
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# When you need a message accept a `Content` object or dict, a list of parts,
|
| 277 |
+
# or a single part
|
| 278 |
+
ContentType = Union[protos.Content, ContentDict, Iterable[PartType], PartType]
|
| 279 |
+
|
| 280 |
+
# For generate_content, we're not guessing roles for [[parts],[parts],[parts]] yet.
|
| 281 |
+
StrictContentType = Union[protos.Content, ContentDict]
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def to_content(content: ContentType):
|
| 285 |
+
if not content:
|
| 286 |
+
raise ValueError(
|
| 287 |
+
"Invalid input: 'content' argument must not be empty. Please provide a non-empty value."
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
if isinstance(content, Mapping):
|
| 291 |
+
content = _convert_dict(content)
|
| 292 |
+
|
| 293 |
+
if isinstance(content, protos.Content):
|
| 294 |
+
return content
|
| 295 |
+
elif isinstance(content, Iterable) and not isinstance(content, str):
|
| 296 |
+
return protos.Content(parts=[to_part(part) for part in content])
|
| 297 |
+
else:
|
| 298 |
+
# Maybe this is a Part?
|
| 299 |
+
return protos.Content(parts=[to_part(content)])
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def strict_to_content(content: StrictContentType):
|
| 303 |
+
if isinstance(content, Mapping):
|
| 304 |
+
content = _convert_dict(content)
|
| 305 |
+
|
| 306 |
+
if isinstance(content, protos.Content):
|
| 307 |
+
return content
|
| 308 |
+
else:
|
| 309 |
+
raise TypeError(
|
| 310 |
+
"Invalid input type. Expected a `protos.Content` or a `dict` with a 'parts' key.\n"
|
| 311 |
+
f"However, received an object of type: {type(content)}.\n"
|
| 312 |
+
f"Object Value: {content}"
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
ContentsType = Union[ContentType, Iterable[StrictContentType], None]
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def to_contents(contents: ContentsType) -> list[protos.Content]:
|
| 320 |
+
if contents is None:
|
| 321 |
+
return []
|
| 322 |
+
|
| 323 |
+
if isinstance(contents, Iterable) and not isinstance(contents, (str, Mapping)):
|
| 324 |
+
try:
|
| 325 |
+
# strict_to_content so [[parts], [parts]] doesn't assume roles.
|
| 326 |
+
contents = [strict_to_content(c) for c in contents]
|
| 327 |
+
return contents
|
| 328 |
+
except TypeError:
|
| 329 |
+
# If you get a TypeError here it's probably because that was a list
|
| 330 |
+
# of parts, not a list of contents, so fall back to `to_content`.
|
| 331 |
+
pass
|
| 332 |
+
|
| 333 |
+
contents = [to_content(contents)]
|
| 334 |
+
return contents
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def _schema_for_class(cls: TypedDict) -> dict[str, Any]:
    """Builds the OpenAPI schema for a `TypedDict` class.

    Wraps `cls` as the single field of a throwaway model so `_build_schema`
    can process it, then extracts that field's schema.
    """
    schema = _build_schema("dummy", {"dummy": (cls, pydantic.Field())})
    return schema["properties"]["dummy"]
def _schema_for_function(
    f: Callable[..., Any],
    *,
    descriptions: Mapping[str, str] | None = None,
    required: Sequence[str] | None = None,
) -> dict[str, Any]:
    """Generates the OpenAPI Schema for a python function.

    Args:
        f: The function to generate an OpenAPI Schema for.
        descriptions: Optional. A `{name: description}` mapping for annotating input
            arguments of the function with user-provided descriptions. It
            defaults to an empty dictionary (i.e. there will not be any
            description for any of the inputs).
        required: Optional. For the user to specify the set of required arguments in
            function calls to `f`. If unspecified, it will be automatically
            inferred from `f`.

    Returns:
        dict[str, Any]: The OpenAPI Schema for the function `f` in JSON format.
    """
    if descriptions is None:
        descriptions = {}
    defaults = dict(inspect.signature(f).parameters)

    fields_dict = {}
    for name, param in defaults.items():
        # Only named parameters become schema fields; *args/**kwargs are skipped.
        if param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
            inspect.Parameter.POSITIONAL_ONLY,
        ):
            # We do not support default values for now.
            # default=(
            #     param.default if param.default != inspect.Parameter.empty
            #     else None
            # ),
            field = pydantic.Field(
                # We support user-provided descriptions.
                description=descriptions.get(name, None)
            )

            # 1. We infer the argument type here: use Any rather than None so
            # it will not try to auto-infer the type based on the default value.
            if param.annotation != inspect.Parameter.empty:
                fields_dict[name] = param.annotation, field
            else:
                fields_dict[name] = Any, field

    parameters = _build_schema(f.__name__, fields_dict)

    # 6. Annotate required fields.
    if required is not None:
        # We use the user-provided "required" fields if specified.
        parameters["required"] = required
    else:
        # Otherwise we infer it from the function signature: a named parameter
        # with no default value is required.
        parameters["required"] = [
            k
            for k in defaults
            if (
                defaults[k].default == inspect.Parameter.empty
                and defaults[k].kind
                in (
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.KEYWORD_ONLY,
                    inspect.Parameter.POSITIONAL_ONLY,
                )
            )
        ]
    schema = dict(name=f.__name__, description=f.__doc__)
    # Omit "parameters" entirely for zero-argument functions.
    if parameters["properties"]:
        schema["parameters"] = parameters

    return schema
def _build_schema(fname, fields_dict):
    """Builds a flattened, proto-friendly JSON schema from pydantic fields.

    Creates a throwaway pydantic model named `fname` from `fields_dict`,
    dumps its JSON schema, and post-processes it in place. The order of the
    post-processing passes matters: refs are inlined first, then `anyOf`
    unions are collapsed to nullable fields, then object types and titles
    are fixed up.
    """
    parameters = pydantic.create_model(fname, **fields_dict).model_json_schema()
    defs = parameters.pop("$defs", {})
    # flatten the defs
    for name, value in defs.items():
        unpack_defs(value, defs)
    unpack_defs(parameters, defs)

    # 5. Nullable fields:
    #     * https://github.com/pydantic/pydantic/issues/1270
    #     * https://stackoverflow.com/a/58841311
    #     * https://github.com/pydantic/pydantic/discussions/4872
    convert_to_nullable(parameters)
    add_object_type(parameters)
    # Postprocessing
    # 4. Suppress unnecessary title generation:
    #     * https://github.com/pydantic/pydantic/issues/1051
    #     * http://cl/586221780
    strip_titles(parameters)
    return parameters
def unpack_defs(schema, defs):
    """Recursively inlines `$defs` references in `schema`, in place.

    Replaces every `{"$ref": ...}` found directly in a property, inside a
    property's `anyOf` options, or as a property's array `items` with the
    referenced definition from `defs` (itself recursively unpacked first).
    Schemas without a "properties" key are left untouched.
    """
    props = schema.get("properties")
    if props is None:
        return

    for key, subschema in props.items():
        ref_path = subschema.get("$ref")
        if ref_path is not None:
            resolved = defs[ref_path.split("defs/")[-1]]
            unpack_defs(resolved, defs)
            props[key] = resolved
            continue

        union = subschema.get("anyOf")
        if union is not None:
            for idx, option in enumerate(union):
                ref_path = option.get("$ref")
                if ref_path is not None:
                    resolved = defs[ref_path.split("defs/")[-1]]
                    unpack_defs(resolved, defs)
                    union[idx] = resolved

        item_schema = subschema.get("items")
        if item_schema is not None:
            ref_path = item_schema.get("$ref")
            if ref_path is not None:
                resolved = defs[ref_path.split("defs/")[-1]]
                unpack_defs(resolved, defs)
                subschema["items"] = resolved
def strip_titles(schema):
    """Removes auto-generated "title" keys from `schema`, in place.

    Descends recursively through every property and through array `items`.
    """
    schema.pop("title", None)

    for subschema in schema.get("properties", {}).values():
        strip_titles(subschema)

    item_schema = schema.get("items")
    if item_schema is not None:
        strip_titles(item_schema)
def add_object_type(schema):
    """Marks every schema node that has "properties" as type "object", in place.

    Also drops any "required" key on those nodes (handled separately by
    `_schema_for_function`), and recurses into properties and array `items`.
    """
    props = schema.get("properties")
    if props is not None:
        schema.pop("required", None)
        schema["type"] = "object"
        for subschema in props.values():
            add_object_type(subschema)

    item_schema = schema.get("items")
    if item_schema is not None:
        add_object_type(item_schema)
def convert_to_nullable(schema):
    """Collapses `Optional[X]` unions into `{..X.., "nullable": True}`, in place.

    A two-element "anyOf" where one arm is `{"type": "null"}` is replaced by
    the other arm plus a "nullable" flag. Any other union is rejected.
    Recurses into properties and array `items`.

    Raises:
        ValueError: If an "anyOf" is not exactly an `Optional` of one type.
    """
    union = schema.pop("anyOf", None)
    if union is not None:
        if len(union) != 2:
            raise ValueError(
                "Invalid input: Type Unions are not supported, except for `Optional` types. "
                "Please provide an `Optional` type or a non-Union type."
            )
        first, second = union
        null_arm = {"type": "null"}
        if first == null_arm:
            schema.update(second)
        elif second == null_arm:
            schema.update(first)
        else:
            raise ValueError(
                "Invalid input: Type Unions are not supported, except for `Optional` types. "
                "Please provide an `Optional` type or a non-Union type."
            )
        schema["nullable"] = True

    for subschema in schema.get("properties", {}).values():
        convert_to_nullable(subschema)

    item_schema = schema.get("items")
    if item_schema is not None:
        convert_to_nullable(item_schema)
def _rename_schema_fields(schema):
|
| 531 |
+
if schema is None:
|
| 532 |
+
return schema
|
| 533 |
+
|
| 534 |
+
schema = schema.copy()
|
| 535 |
+
|
| 536 |
+
type_ = schema.pop("type", None)
|
| 537 |
+
if type_ is not None:
|
| 538 |
+
schema["type_"] = type_.upper()
|
| 539 |
+
|
| 540 |
+
format_ = schema.pop("format", None)
|
| 541 |
+
if format_ is not None:
|
| 542 |
+
schema["format_"] = format_
|
| 543 |
+
|
| 544 |
+
items = schema.pop("items", None)
|
| 545 |
+
if items is not None:
|
| 546 |
+
schema["items"] = _rename_schema_fields(items)
|
| 547 |
+
|
| 548 |
+
properties = schema.pop("properties", None)
|
| 549 |
+
if properties is not None:
|
| 550 |
+
schema["properties"] = {k: _rename_schema_fields(v) for k, v in properties.items()}
|
| 551 |
+
|
| 552 |
+
return schema
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
class FunctionDeclaration:
    """A class wrapping a `protos.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`."""

    # Fix: this docstring was previously inside `__init__`, leaving the class
    # itself undocumented (`FunctionDeclaration.__doc__` was None).

    def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None):
        # `_rename_schema_fields` converts JSON-schema keys ("type", "format")
        # into the proto field names ("type_", "format_").
        self._proto = protos.FunctionDeclaration(
            name=name, description=description, parameters=_rename_schema_fields(parameters)
        )

    @property
    def name(self) -> str:
        return self._proto.name

    @property
    def description(self) -> str:
        return self._proto.description

    @property
    def parameters(self) -> protos.Schema:
        return self._proto.parameters

    @classmethod
    def from_proto(cls, proto) -> FunctionDeclaration:
        """Wraps an existing `protos.FunctionDeclaration` without re-encoding it."""
        self = cls(name="", description="", parameters={})
        self._proto = proto
        return self

    def to_proto(self) -> protos.FunctionDeclaration:
        """Returns the underlying `protos.FunctionDeclaration`."""
        return self._proto

    @staticmethod
    def from_function(function: Callable[..., Any], descriptions: dict[str, str] | None = None):
        """Builds a `CallableFunctionDeclaration` from a python function.

        The function should have type annotations.

        This method is able to generate the schema for arguments annotated with types:

        `AllowedTypes = float | int | str | list[AllowedTypes] | dict`

        This method does not yet build a schema for `TypedDict`, that would allow you to specify the dictionary
        contents. But you can build these manually.
        """

        if descriptions is None:
            descriptions = {}

        schema = _schema_for_function(function, descriptions=descriptions)

        return CallableFunctionDeclaration(**schema, function=function)
StructType = dict[str, "ValueType"]
|
| 606 |
+
ValueType = Union[float, str, bool, StructType, list["ValueType"], None]
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
class CallableFunctionDeclaration(FunctionDeclaration):
    """An extension of `FunctionDeclaration` that can be built from a python function, and is callable.

    Note: The python function must have type annotations.
    """

    def __init__(
        self,
        *,
        name: str,
        description: str,
        parameters: dict[str, Any] | None = None,
        function: Callable[..., Any],
    ):
        super().__init__(name=name, description=description, parameters=parameters)
        # The wrapped python callable, invoked by `__call__`.
        self.function = function

    def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse:
        """Invokes the wrapped function with the call's args and wraps the result."""
        result = self.function(**fc.args)
        # The proto response field requires a mapping; wrap bare return values.
        if not isinstance(result, dict):
            result = {"result": result}
        return protos.FunctionResponse(name=fc.name, response=result)
FunctionDeclarationType = Union[
|
| 634 |
+
FunctionDeclaration,
|
| 635 |
+
protos.FunctionDeclaration,
|
| 636 |
+
dict[str, Any],
|
| 637 |
+
Callable[..., Any],
|
| 638 |
+
]
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def _make_function_declaration(
    fun: FunctionDeclarationType,
) -> FunctionDeclaration | protos.FunctionDeclaration:
    """Coerces `fun` (declaration, proto, dict, or callable) into a function declaration.

    A dict containing a "function" key becomes a `CallableFunctionDeclaration`;
    any other dict becomes a plain `FunctionDeclaration`. A bare callable has
    its schema inferred from its signature.

    Raises:
        TypeError: If `fun` is none of the accepted forms.
    """
    if isinstance(fun, (FunctionDeclaration, protos.FunctionDeclaration)):
        return fun
    elif isinstance(fun, dict):
        if "function" in fun:
            return CallableFunctionDeclaration(**fun)
        else:
            return FunctionDeclaration(**fun)
    elif callable(fun):
        return CallableFunctionDeclaration.from_function(fun)
    else:
        raise TypeError(
            "Invalid input type. Expected an instance of `genai.FunctionDeclarationType`.\n"
            f"However, received an object of type: {type(fun)}.\n"
            f"Object Value: {fun}"
        )
def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.FunctionDeclaration:
    """Returns the `protos.FunctionDeclaration` for `fd`, unwrapping if needed."""
    already_proto = isinstance(fd, protos.FunctionDeclaration)
    return fd if already_proto else fd.to_proto()
class DynamicRetrievalConfigDict(TypedDict):
    """A dict form of `protos.DynamicRetrievalConfig`."""

    # NOTE(review): annotation references lowercase `.mode` — presumably the
    # `Mode` enum type was intended; confirm against `protos`.
    mode: protos.DynamicRetrievalConfig.mode
    dynamic_threshold: float
DynamicRetrievalConfig = Union[protos.DynamicRetrievalConfig, DynamicRetrievalConfigDict]
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
class GoogleSearchRetrievalDict(TypedDict):
    """A dict form of `protos.GoogleSearchRetrieval`."""

    dynamic_retrieval_config: DynamicRetrievalConfig
GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, GoogleSearchRetrievalDict]
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType):
    """Coerces `gsr` into a `protos.GoogleSearchRetrieval`.

    Accepts an existing proto (returned as-is) or a mapping; a mapping's
    nested `dynamic_retrieval_config.mode` is normalized via `to_mode`.
    The caller's mapping is never modified.

    Raises:
        TypeError: If `gsr` is neither a proto nor a mapping.
    """
    if isinstance(gsr, protos.GoogleSearchRetrieval):
        return gsr
    elif isinstance(gsr, Mapping):
        drc = gsr.get("dynamic_retrieval_config", None)
        if drc is not None and isinstance(drc, Mapping):
            mode = drc.get("mode", None)
            if mode is not None:
                mode = to_mode(mode)
                # Fix: copy BOTH levels before assigning. The previous shallow
                # `gsr.copy()` still shared the inner dict, so the caller's
                # nested "dynamic_retrieval_config" was silently mutated.
                gsr = dict(gsr)
                gsr["dynamic_retrieval_config"] = dict(drc)
                gsr["dynamic_retrieval_config"]["mode"] = mode
        return protos.GoogleSearchRetrieval(gsr)
    else:
        raise TypeError(
            "Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n"
            f"However, received an object of type: {type(gsr)}.\n"
            f"Object Value: {gsr}"
        )
class Tool:
    """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects,
    protos.CodeExecution object, and protos.GoogleSearchRetrieval object."""

    def __init__(
        self,
        *,
        function_declarations: Iterable[FunctionDeclarationType] | None = None,
        google_search_retrieval: GoogleSearchRetrievalType | None = None,
        code_execution: protos.CodeExecution | None = None,
    ):
        # The main path doesn't use this but is seems useful.
        if function_declarations is not None:
            self._function_declarations = [
                _make_function_declaration(f) for f in function_declarations
            ]
            self._index = {}
            for fd in self._function_declarations:
                name = fd.name
                if name in self._index:
                    # Fix: previously raised ValueError("") with an empty,
                    # unhelpful message.
                    raise ValueError(
                        f"Invalid operation: A `FunctionDeclaration` named '{name}' is already defined. "
                        "Each `FunctionDeclaration` must have a unique name. Please use a different name."
                    )
                self._index[fd.name] = fd
        else:
            # Consistent fields
            self._function_declarations = []
            self._index = {}

        if google_search_retrieval is not None:
            self._google_search_retrieval = _make_google_search_retrieval(google_search_retrieval)
        else:
            self._google_search_retrieval = None

        # NOTE(review): the proto is built from the *raw* `google_search_retrieval`
        # argument rather than the normalized `self._google_search_retrieval` —
        # confirm this is intentional.
        self._proto = protos.Tool(
            function_declarations=[_encode_fd(fd) for fd in self._function_declarations],
            google_search_retrieval=google_search_retrieval,
            code_execution=code_execution,
        )

    @property
    def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]:
        return self._function_declarations

    @property
    def google_search_retrieval(self) -> protos.GoogleSearchRetrieval:
        return self._google_search_retrieval

    @property
    def code_execution(self) -> protos.CodeExecution:
        return self._proto.code_execution

    def __getitem__(
        self, name: str | protos.FunctionCall
    ) -> FunctionDeclaration | protos.FunctionDeclaration:
        """Looks up a declaration by name, or by the name on a `FunctionCall`."""
        if not isinstance(name, str):
            name = name.name

        return self._index[name]

    def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse | None:
        """Executes `fc` if its declaration is callable; returns None otherwise."""
        declaration = self[fc]
        if not callable(declaration):
            return None

        return declaration(fc)

    def to_proto(self):
        """Returns the underlying `protos.Tool`."""
        return self._proto
class ToolDict(TypedDict):
    """A dict form of a `Tool`: a bundle of function declarations."""

    function_declarations: list[FunctionDeclarationType]
ToolType = Union[
|
| 777 |
+
str, Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType
|
| 778 |
+
]
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
def _make_tool(tool: ToolType) -> Tool:
    """Coerces any `ToolType` value into a `Tool`.

    Accepts: an existing `Tool` or `protos.Tool`; a dict (either a tool dict
    or a single function-declaration dict); the strings "code_execution" /
    "google_search_retrieval"; a bare `protos.CodeExecution` or
    `protos.GoogleSearchRetrieval`; an iterable of function declarations; or
    a single declaration-like object.

    Raises:
        ValueError: For an unrecognized string.
        TypeError: If `tool` cannot be converted by any of the above paths.
    """
    if isinstance(tool, Tool):
        return tool
    elif isinstance(tool, protos.Tool):
        # Re-wrap the proto, carrying over only the fields that are set.
        if "code_execution" in tool:
            code_execution = tool.code_execution
        else:
            code_execution = None

        if "google_search_retrieval" in tool:
            google_search_retrieval = tool.google_search_retrieval
        else:
            google_search_retrieval = None

        return Tool(
            function_declarations=tool.function_declarations,
            google_search_retrieval=google_search_retrieval,
            code_execution=code_execution,
        )
    elif isinstance(tool, dict):
        if (
            "function_declarations" in tool
            or "google_search_retrieval" in tool
            or "code_execution" in tool
        ):
            return Tool(**tool)
        else:
            # Not a tool dict: treat it as a single function declaration.
            fd = tool
            return Tool(function_declarations=[protos.FunctionDeclaration(**fd)])
    elif isinstance(tool, str):
        if tool.lower() == "code_execution":
            return Tool(code_execution=protos.CodeExecution())
        # Check to see if one of the mode enums matches
        # NOTE(review): despite the comment and the error message below, only
        # the exact string "google_search_retrieval" is accepted here — mode
        # values are not matched. Confirm intent.
        elif tool.lower() == "google_search_retrieval":
            return Tool(google_search_retrieval=protos.GoogleSearchRetrieval())
        else:
            raise ValueError(
                "The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval."
            )
    elif isinstance(tool, protos.CodeExecution):
        return Tool(code_execution=tool)
    elif isinstance(tool, protos.GoogleSearchRetrieval):
        return Tool(google_search_retrieval=tool)
    elif isinstance(tool, Iterable):
        return Tool(function_declarations=tool)
    else:
        # Last resort: maybe it's a single declaration-like object.
        try:
            return Tool(function_declarations=[tool])
        except Exception as e:
            raise TypeError(
                "Invalid input type. Expected an instance of `genai.ToolType`.\n"
                f"However, received an object of type: {type(tool)}.\n"
                f"Object Value: {tool}"
            ) from e
class FunctionLibrary:
    """A container for a set of `Tool` objects, manages lookup and execution of their functions."""

    def __init__(self, tools: Iterable[ToolType]):
        tools = _make_tools(tools)
        self._tools = list(tools)
        # Flat name -> declaration index across all tools; names must be
        # unique across the whole library.
        self._index = {}
        for tool in self._tools:
            for declaration in tool.function_declarations:
                name = declaration.name
                if name in self._index:
                    raise ValueError(
                        f"Invalid operation: A `FunctionDeclaration` named '{name}' is already defined. "
                        "Each `FunctionDeclaration` must have a unique name. Please use a different name."
                    )
                self._index[declaration.name] = declaration

    def __getitem__(
        self, name: str | protos.FunctionCall
    ) -> FunctionDeclaration | protos.FunctionDeclaration:
        """Looks up a declaration by name, or by the name on a `FunctionCall`."""
        if not isinstance(name, str):
            name = name.name

        return self._index[name]

    def __call__(self, fc: protos.FunctionCall) -> protos.Part | None:
        """Executes `fc` if its declaration is callable; wraps the response in a `Part`."""
        declaration = self[fc]
        if not callable(declaration):
            return None

        response = declaration(fc)
        return protos.Part(function_response=response)

    def to_proto(self):
        """Returns the list of `protos.Tool` for all tools in the library."""
        return [tool.to_proto() for tool in self._tools]
ToolsType = Union[Iterable[ToolType], ToolType]
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
def _make_tools(tools: ToolsType) -> list[Tool]:
    """Coerces `tools` (a single tool spec or an iterable of them) to `list[Tool]`.

    Multiple single-declaration tools are flattened into one `Tool`.

    Raises:
        ValueError: For an unrecognized string spec.
    """
    if isinstance(tools, str):
        lowered = tools.lower()
        if lowered == "code_execution" or lowered == "google_search_retrieval":
            return [_make_tool(tools)]
        else:
            # Fix: the message previously claimed only 'code_execution' was
            # accepted, contradicting the guard above.
            raise ValueError(
                "The only strings that can be passed as a tool are 'code_execution' "
                "and 'google_search_retrieval'."
            )
    elif isinstance(tools, Iterable) and not isinstance(tools, Mapping):
        tools = [_make_tool(t) for t in tools]
        if len(tools) > 1 and all(len(t.function_declarations) == 1 for t in tools):
            # flatten into a single tool.
            tools = [_make_tool([t.function_declarations[0] for t in tools])]
        return tools
    else:
        tool = tools
        return [_make_tool(tool)]
FunctionLibraryType = Union[FunctionLibrary, ToolsType]
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | None:
    """Normalizes `lib` to a `FunctionLibrary`, passing through `None` and existing libraries."""
    if lib is None or isinstance(lib, FunctionLibrary):
        return lib
    return FunctionLibrary(tools=lib)
FunctionCallingMode = protos.FunctionCallingConfig.Mode
|
| 907 |
+
|
| 908 |
+
# fmt: off
|
| 909 |
+
_FUNCTION_CALLING_MODE = {
|
| 910 |
+
1: FunctionCallingMode.AUTO,
|
| 911 |
+
FunctionCallingMode.AUTO: FunctionCallingMode.AUTO,
|
| 912 |
+
"mode_auto": FunctionCallingMode.AUTO,
|
| 913 |
+
"auto": FunctionCallingMode.AUTO,
|
| 914 |
+
|
| 915 |
+
2: FunctionCallingMode.ANY,
|
| 916 |
+
FunctionCallingMode.ANY: FunctionCallingMode.ANY,
|
| 917 |
+
"mode_any": FunctionCallingMode.ANY,
|
| 918 |
+
"any": FunctionCallingMode.ANY,
|
| 919 |
+
|
| 920 |
+
3: FunctionCallingMode.NONE,
|
| 921 |
+
FunctionCallingMode.NONE: FunctionCallingMode.NONE,
|
| 922 |
+
"mode_none": FunctionCallingMode.NONE,
|
| 923 |
+
"none": FunctionCallingMode.NONE,
|
| 924 |
+
}
|
| 925 |
+
# fmt: on
|
| 926 |
+
|
| 927 |
+
FunctionCallingModeType = Union[FunctionCallingMode, str, int]
|
| 928 |
+
|
| 929 |
+
|
| 930 |
+
def to_function_calling_mode(x: FunctionCallingModeType) -> FunctionCallingMode:
    """Maps a string, int, or enum spec to a `FunctionCallingMode` value.

    Strings are matched case-insensitively against the `_FUNCTION_CALLING_MODE`
    lookup table.

    Raises:
        KeyError: If `x` is not a recognized mode spec.
    """
    key = x.lower() if isinstance(x, str) else x
    return _FUNCTION_CALLING_MODE[key]
class FunctionCallingConfigDict(TypedDict):
    """A dict form of `protos.FunctionCallingConfig`."""

    mode: FunctionCallingModeType
    allowed_function_names: list[str]
FunctionCallingConfigType = Union[
|
| 942 |
+
FunctionCallingModeType, FunctionCallingConfigDict, protos.FunctionCallingConfig
|
| 943 |
+
]
|
| 944 |
+
|
| 945 |
+
|
| 946 |
+
def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.FunctionCallingConfig:
    """Coerces a mode spec, dict, or proto into a `protos.FunctionCallingConfig`.

    The caller's dict is not modified.

    Raises:
        TypeError: If `obj` is none of the accepted forms.
    """
    if isinstance(obj, protos.FunctionCallingConfig):
        return obj

    if isinstance(obj, (FunctionCallingMode, str, int)):
        obj = {"mode": to_function_calling_mode(obj)}
    elif isinstance(obj, dict):
        normalized = dict(obj)
        normalized["mode"] = to_function_calling_mode(normalized.pop("mode"))
        obj = normalized
    else:
        raise TypeError(
            "Invalid input type. Failed to convert input to `protos.FunctionCallingConfig`.\n"
            f"Received an object of type: {type(obj)}.\n"
            f"Object Value: {obj}"
        )

    return protos.FunctionCallingConfig(obj)
|
| 965 |
+
class ToolConfigDict:
    """A dict-style spec for `protos.ToolConfig`.

    NOTE(review): unlike the other `*Dict` classes in this module, this does
    not inherit from `TypedDict`, so it is only usable as an annotation and
    cannot be instantiated as a typed dict — confirm whether `TypedDict` was
    intended.
    """

    function_calling_config: FunctionCallingConfigType
ToolConfigType = Union[ToolConfigDict, protos.ToolConfig]
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
def to_tool_config(obj: ToolConfigType) -> protos.ToolConfig:
    """Coerces `obj` (proto or dict) into a `protos.ToolConfig`.

    The caller's dict is not modified.

    Raises:
        TypeError: If `obj` is neither a `protos.ToolConfig` nor a dict.
        KeyError: If a dict input lacks the "function_calling_config" key.
    """
    if isinstance(obj, protos.ToolConfig):
        return obj
    elif isinstance(obj, dict):
        # Fix: copy before popping so the caller's dict isn't mutated
        # (consistent with `to_function_calling_config`).
        obj = obj.copy()
        fcc = obj.pop("function_calling_config")
        fcc = to_function_calling_config(fcc)
        obj["function_calling_config"] = fcc
        return protos.ToolConfig(**obj)
    else:
        raise TypeError(
            "Invalid input type. Failed to convert input to `protos.ToolConfig`.\n"
            f"Received an object of type: {type(obj)}.\n"
            f"Object Value: {obj}"
        )
.venv/lib/python3.11/site-packages/google/generativeai/types/discuss_types.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Type definitions for the discuss service."""
|
| 16 |
+
|
| 17 |
+
import abc
|
| 18 |
+
import dataclasses
|
| 19 |
+
from typing import Any, Dict, Union, Iterable, Optional, Tuple, List
|
| 20 |
+
from typing_extensions import TypedDict
|
| 21 |
+
|
| 22 |
+
import google.ai.generativelanguage as glm
|
| 23 |
+
from google.generativeai import string_utils
|
| 24 |
+
|
| 25 |
+
from google.generativeai.types import safety_types
|
| 26 |
+
from google.generativeai.types import citation_types
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
__all__ = [
|
| 30 |
+
"MessageDict",
|
| 31 |
+
"MessageOptions",
|
| 32 |
+
"MessagesOptions",
|
| 33 |
+
"ExampleDict",
|
| 34 |
+
"ExampleOptions",
|
| 35 |
+
"ExamplesOptions",
|
| 36 |
+
"MessagePromptDict",
|
| 37 |
+
"MessagePromptOptions",
|
| 38 |
+
"ResponseDict",
|
| 39 |
+
"ChatResponse",
|
| 40 |
+
"AuthorError",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TokenCount(TypedDict):
    """A dict holding a token count result."""

    token_count: int
+
class MessageDict(TypedDict):
    """A dict representation of a `glm.Message`."""

    author: str
    content: str  # the message text
    citation_metadata: Optional[citation_types.CitationMetadataDict]
+
MessageOptions = Union[str, MessageDict, glm.Message]
|
| 57 |
+
MESSAGE_OPTIONS = (str, dict, glm.Message)
|
| 58 |
+
|
| 59 |
+
MessagesOptions = Union[
|
| 60 |
+
MessageOptions,
|
| 61 |
+
Iterable[MessageOptions],
|
| 62 |
+
]
|
| 63 |
+
MESSAGES_OPTIONS = (MESSAGE_OPTIONS, Iterable)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ExampleDict(TypedDict):
    """A dict representation of a `glm.Example`."""

    input: MessageOptions  # the example input message
    output: MessageOptions  # the example output message
ExampleOptions = Union[
|
| 74 |
+
Tuple[MessageOptions, MessageOptions],
|
| 75 |
+
Iterable[MessageOptions],
|
| 76 |
+
ExampleDict,
|
| 77 |
+
glm.Example,
|
| 78 |
+
]
|
| 79 |
+
EXAMPLE_OPTIONS = (glm.Example, dict, Iterable)
|
| 80 |
+
ExamplesOptions = Union[ExampleOptions, Iterable[ExampleOptions]]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class MessagePromptDict(TypedDict, total=False):
    """A dict representation of a `glm.MessagePrompt`.

    All keys are optional (`total=False`).
    """

    context: str  # text provided to the model first, to ground the response
    examples: ExamplesOptions  # examples of what the model should generate
    messages: MessagesOptions  # the conversation history
+
MessagePromptOptions = Union[
|
| 92 |
+
str,
|
| 93 |
+
glm.Message,
|
| 94 |
+
Iterable[Union[str, glm.Message]],
|
| 95 |
+
MessagePromptDict,
|
| 96 |
+
glm.MessagePrompt,
|
| 97 |
+
]
|
| 98 |
+
MESSAGE_PROMPT_KEYS = {"context", "examples", "messages"}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class ResponseDict(TypedDict):
    """A dict representation of a `glm.GenerateMessageResponse`."""

    messages: List[MessageDict]  # the conversation history plus the top candidate
    candidates: List[MessageDict]  # all responses generated by the model
+
@string_utils.prettyprint
@dataclasses.dataclass(init=False)
class ChatResponse(abc.ABC):
    """A chat response from the model.

    * Use `response.last` (settable) for easy access to the text of the last
      response (`messages[-1]['content']`).
    * Use `response.messages` to access the message history (including `.last`).
    * Use `response.candidates` to access all the responses generated by the model.

    Other attributes are just saved from the arguments to `genai.chat`, so you
    can easily continue a conversation:

    ```
    import google.generativeai as genai

    genai.configure(api_key=os.environ['GOOGLE_API_KEY'])

    response = genai.chat(messages=["Hello."])
    print(response.last) # 'Hello! What can I help you with?'
    response.reply("Can you tell me a joke?")
    ```

    See `genai.chat` for more details.

    Attributes:
        candidates: A list of candidate responses from the model.

            The top candidate is appended to the `messages` field.

            This list will contain a *maximum* of `candidate_count` candidates.
            It may contain fewer (duplicates are dropped), it will contain at
            least one.

            Note: The `temperature` field affects the variability of the
            responses. Low temperatures will return few candidates. Setting
            `temperature=0` is deterministic, so it will only ever return one
            candidate.
        filters: This indicates which `types.SafetyCategory`(s) blocked a
            candidate from this response, the lowest `types.HarmProbability`
            that triggered a block, and the `types.HarmThreshold` setting for
            that category. This indicates the smallest change to the
            `types.SafetySettings` that would be necessary to unblock at least
            1 response.

            The blocking is configured by the `types.SafetySettings` in the
            request (or the default `types.SafetySettings` of the API).
        messages: A snapshot of the conversation history sorted chronologically:
            all the `messages` that were passed when the model was called, plus
            the top `candidate` message.
        model: The model name.
        context: Text that should be provided to the model first, to ground the
            response.
        examples: Examples of what the model should generate.
        temperature: Controls the randomness of the output. Must be positive.
        candidate_count: The **maximum** number of generated response messages
            to return.
        top_k: The maximum number of tokens to consider when sampling.
        top_p: The maximum cumulative probability of tokens to consider when
            sampling.
    """

    model: str
    context: str
    examples: List[ExampleDict]
    messages: List[Optional[MessageDict]]
    temperature: Optional[float]
    candidate_count: Optional[int]
    candidates: List[MessageDict]
    filters: List[safety_types.ContentFilterDict]
    top_p: Optional[float] = None
    top_k: Optional[float] = None

    @property
    @abc.abstractmethod
    def last(self) -> Optional[str]:
        """A settable property that provides simple access to the last response string.

        A shortcut for `response.messages[-1]['content']`.
        """
        pass

    def to_dict(self) -> Dict[str, Any]:
        """Returns the response as a JSON-compatible dict.

        NOTE: `filters` is not included in the returned dict.
        """
        result = {
            "model": self.model,
            "context": self.context,
            "examples": self.examples,
            "messages": self.messages,
            "temperature": self.temperature,
            "candidate_count": self.candidate_count,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "candidates": self.candidates,
        }
        return result

    @abc.abstractmethod
    def reply(self, message: MessageOptions) -> "ChatResponse":
        "Add a message to the conversation, and get the model's response."
        pass
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class AuthorError(Exception):
    """Raised by the `chat` (or `reply`) functions when the author list can't be normalized."""
|
.venv/lib/python3.11/site-packages/google/generativeai/types/file_types.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import datetime
|
| 18 |
+
from typing import Any, Union
|
| 19 |
+
from typing_extensions import TypedDict
|
| 20 |
+
|
| 21 |
+
from google.rpc.status_pb2 import Status
|
| 22 |
+
from google.generativeai.client import get_default_file_client
|
| 23 |
+
|
| 24 |
+
from google.generativeai import protos
|
| 25 |
+
|
| 26 |
+
import pprint
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class File:
    """A thin wrapper around a `protos.File` message describing an uploaded file."""

    def __init__(self, proto: protos.File | File | dict):
        # Accept another `File` (copy its proto), a `protos.File`, or a dict
        # that `protos.File` can be constructed from.
        if isinstance(proto, File):
            proto = proto.to_proto()
        self._proto = protos.File(proto)

    def to_proto(self) -> protos.File:
        """Returns the underlying `protos.File` message."""
        return self._proto

    def to_dict(self) -> dict[str, Any]:
        """Returns the file metadata as a JSON-compatible dict (enums as strings)."""
        return type(self._proto).to_dict(self._proto, use_integers_for_enums=False)

    def __str__(self):
        # Pretty-print the metadata with `name` first and `*time*` fields last.
        def sort_key(pair):
            name, value = pair
            if name == "name":
                return ""
            elif "time" in name:
                # "zz_" prefix pushes time fields to the end of the sort order.
                return "zz_" + name
            else:
                return name

        dict_format = dict(sorted(self.to_dict().items(), key=sort_key))
        dict_format = pprint.pformat(dict_format, sort_dicts=False)
        dict_format = "{\n " + dict_format[1:]
        dict_format = "\n ".join(dict_format.splitlines())
        return dict_format.join(["genai.File(", ")"])

    __repr__ = __str__

    @property
    def name(self) -> str:
        # The file's resource name.
        return self._proto.name

    @property
    def display_name(self) -> str:
        return self._proto.display_name

    @property
    def mime_type(self) -> str:
        return self._proto.mime_type

    @property
    def size_bytes(self) -> int:
        return self._proto.size_bytes

    @property
    def create_time(self) -> datetime.datetime:
        return self._proto.create_time

    @property
    def update_time(self) -> datetime.datetime:
        return self._proto.update_time

    @property
    def expiration_time(self) -> datetime.datetime:
        return self._proto.expiration_time

    @property
    def sha256_hash(self) -> bytes:
        return self._proto.sha256_hash

    @property
    def uri(self) -> str:
        return self._proto.uri

    @property
    def state(self) -> protos.File.State:
        return self._proto.state

    @property
    def video_metadata(self) -> protos.VideoMetadata:
        return self._proto.video_metadata

    @property
    def error(self) -> Status:
        return self._proto.error

    def delete(self):
        """Deletes this file on the service via the default file client."""
        client = get_default_file_client()
        client.delete_file(name=self.name)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class FileDataDict(TypedDict):
    """A dict representation of a `protos.FileData`."""

    mime_type: str
    file_uri: str


# Any value convertible to a `protos.FileData` by `to_file_data`.
FileDataType = Union[FileDataDict, protos.FileData, protos.File, File]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def to_file_data(file_data: FileDataType) -> protos.FileData:
    """Normalizes any `FileDataType` value into a `protos.FileData`.

    Dicts containing a `file_uri` key are treated as `FileData` dicts; other
    dicts are treated as `File` dicts. `File` wrappers and `protos.File`
    messages are reduced to their `mime_type` and `uri` fields.

    Raises:
        TypeError: If the value cannot be converted.
    """
    # Dicts: decide which proto the keys describe.
    if isinstance(file_data, dict):
        proto_cls = protos.FileData if "file_uri" in file_data else protos.File
        file_data = proto_cls(file_data)

    # Unwrap the high-level `File` wrapper down to its proto.
    if isinstance(file_data, File):
        file_data = file_data.to_proto()

    # A full `protos.File` carries more than we need; keep only the reference.
    if isinstance(file_data, protos.File):
        file_data = protos.FileData(
            mime_type=file_data.mime_type,
            file_uri=file_data.uri,
        )

    if not isinstance(file_data, protos.FileData):
        raise TypeError(
            f"Invalid input type. Failed to convert input to `FileData`.\n"
            f"Received an object of type: {type(file_data)}.\n"
            f"Object Value: {file_data}"
        )
    return file_data
|
.venv/lib/python3.11/site-packages/google/generativeai/types/generation_types.py
ADDED
|
@@ -0,0 +1,759 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import contextlib
|
| 19 |
+
from collections.abc import Iterable, AsyncIterable, Mapping
|
| 20 |
+
import dataclasses
|
| 21 |
+
import itertools
|
| 22 |
+
import json
|
| 23 |
+
import sys
|
| 24 |
+
import textwrap
|
| 25 |
+
from typing import Union, Any
|
| 26 |
+
from typing_extensions import TypedDict
|
| 27 |
+
import types
|
| 28 |
+
|
| 29 |
+
import google.protobuf.json_format
|
| 30 |
+
import google.api_core.exceptions
|
| 31 |
+
|
| 32 |
+
from google.generativeai import protos
|
| 33 |
+
from google.generativeai import string_utils
|
| 34 |
+
from google.generativeai.types import content_types
|
| 35 |
+
from google.generativeai.responder import _rename_schema_fields
|
| 36 |
+
|
| 37 |
+
# Public names exported by this module.
__all__ = [
    "AsyncGenerateContentResponse",
    "BlockedPromptException",
    "StopCandidateException",
    "IncompleteIterationError",
    "BrokenResponseError",
    "GenerationConfigDict",
    "GenerationConfigType",
    "GenerationConfig",
    "GenerateContentResponse",
]
|
| 48 |
+
|
| 49 |
+
if sys.version_info < (3, 10):
    # Backports of the `aiter`/`anext` builtins added in Python 3.10.

    # Sentinel so an explicitly passed `default=None` can be distinguished
    # from "no default given". The previous backport tested
    # `default is not None`, so `anext(it, None)` wrongly re-raised
    # StopAsyncIteration instead of returning None.
    _ANEXT_MISSING = object()

    def aiter(obj):
        """Returns an async iterator for the given async iterable."""
        return obj.__aiter__()

    async def anext(obj, default=_ANEXT_MISSING):
        """Returns the next item from an async iterator.

        If the iterator is exhausted, returns `default` when one was supplied,
        otherwise re-raises `StopAsyncIteration` (matching the 3.10 builtin).
        """
        try:
            return await obj.__anext__()
        except StopAsyncIteration:
            if default is not _ANEXT_MISSING:
                return default
            raise
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class BlockedPromptException(Exception):
    """Raised when the prompt was blocked (see `response.prompt_feedback.block_reason`)."""

    pass
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class StopCandidateException(Exception):
    """Raised in relation to a candidate's abnormal finish.

    NOTE(review): raising sites are outside this chunk — confirm exact semantics.
    """

    pass
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class IncompleteIterationError(Exception):
    """Raised when accumulated response attributes are read before a stream is fully iterated.

    Let the stream finish, or call `response.resolve()`.
    """

    pass
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class BrokenResponseError(Exception):
    """Raised when a response is in an unusable/broken state.

    NOTE(review): raising sites are outside this chunk — confirm exact semantics.
    """

    pass
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class GenerationConfigDict(TypedDict, total=False):
    """A dict form of `GenerationConfig`; all keys optional.

    Kept field-for-field consistent with the `GenerationConfig` dataclass
    (previously `top_p`/`top_k` were missing here although the dataclass and
    the API accept them).
    """

    # TODO(markdaoust): Python 3.11+ use `NotRequired`, ref: https://peps.python.org/pep-0655/
    candidate_count: int
    stop_sequences: Iterable[str]
    max_output_tokens: int
    temperature: float
    top_p: float
    top_k: int
    response_mime_type: str
    response_schema: protos.Schema | Mapping[str, Any]  # fmt: off
    presence_penalty: float
    frequency_penalty: float
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclasses.dataclass
class GenerationConfig:
    """A simple dataclass used to configure the generation parameters of
    `GenerativeModel.generate_content`.

    Attributes:
        candidate_count: Number of generated responses to return.
        stop_sequences: The set of character sequences (up to 5) that will
            stop output generation. If specified, the API will stop at the
            first appearance of a stop sequence. The stop sequence will not be
            included as part of the response.
        max_output_tokens: The maximum number of tokens to include in a
            candidate. If unset, this defaults to the `output_token_limit`
            specified in the model's specification.
        temperature: Controls the randomness of the output. Values can range
            over [0.0, 1.0], inclusive. A value closer to 1.0 will produce
            responses that are more varied and creative, while a value closer
            to 0.0 will typically result in more straightforward responses.
            Note: the default value varies by model; see the
            `Model.temperature` attribute of the `Model` returned by the
            `genai.get_model` function.
        top_p: Optional. The maximum cumulative probability of tokens to
            consider when sampling. The model uses combined top-k and nucleus
            sampling: tokens are sorted by assigned probability, top-k limits
            the maximum number of tokens considered, and nucleus sampling
            limits tokens by cumulative probability. Note: the default value
            varies by model; see `Model.top_p` via `genai.get_model`.
        top_k: Optional. The maximum number of tokens to consider when
            sampling. Top-k sampling considers the set of `top_k` most
            probable tokens (defaults to 40). Note: the default value varies
            by model; see `Model.top_k` via `genai.get_model`.
        response_mime_type: Optional. Output response mimetype of the
            generated candidate text. Supported mimetypes:
            `text/plain` (default), `text/x-enum` (for use with a string-enum
            `response_schema`), and `application/json` (JSON response in the
            candidates).
        response_schema: Optional. Specifies the format of the JSON requested
            if `response_mime_type` is `application/json`.
        presence_penalty: Optional.
        frequency_penalty: Optional.
    """

    candidate_count: int | None = None
    stop_sequences: Iterable[str] | None = None
    max_output_tokens: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    response_mime_type: str | None = None
    response_schema: protos.Schema | Mapping[str, Any] | type | None = None
    presence_penalty: float | None = None
    frequency_penalty: float | None = None
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# Any accepted generation-config representation: the proto message, a plain
# dict (`GenerationConfigDict`), or the `GenerationConfig` dataclass.
GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig]
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _normalize_schema(generation_config):
    """Converts `generation_config["response_schema"]` to a `protos.Schema`, in place.

    Accepts an existing `protos.Schema` (left untouched), a Python type or a
    `list[...]` generic alias (converted via `content_types._schema_for_class`),
    or a mapping. No-op when no schema is set.

    Raises:
        ValueError: For generic aliases other than `list[...]`.
    """
    # Convert response_schema to protos.Schema for request
    response_schema = generation_config.get("response_schema", None)
    if response_schema is None:
        return

    if isinstance(response_schema, protos.Schema):
        return

    if isinstance(response_schema, type):
        # A plain class (int, str, TypedDict, dataclass, enum, ...).
        response_schema = content_types._schema_for_class(response_schema)
    elif isinstance(response_schema, types.GenericAlias):
        # Only `list[...]` aliases are supported.
        if not str(response_schema).startswith("list["):
            raise ValueError(
                f"Invalid input: Could not understand the type of '{response_schema}'. "
                "Expected one of the following types: `int`, `float`, `str`, `bool`, `enum`, "
                "`typing_extensions.TypedDict`, `dataclass` or `list[...]`."
            )
        response_schema = content_types._schema_for_class(response_schema)

    # `_rename_schema_fields` adjusts the mapping's keys before the proto is
    # built (see `google.generativeai.responder`).
    response_schema = _rename_schema_fields(response_schema)
    generation_config["response_schema"] = protos.Schema(response_schema)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def to_generation_config_dict(generation_config: GenerationConfigType):
    """Normalizes any `GenerationConfigType` value into a plain dict.

    Raises:
        TypeError: If the value is not `None`, a `protos.GenerationConfig`,
            a `GenerationConfig`, or a mapping.
    """
    if generation_config is None:
        return {}
    elif isinstance(generation_config, protos.GenerationConfig):
        # Keep the original proto schema object rather than its dict form.
        schema = generation_config.response_schema
        generation_config = type(generation_config).to_dict(
            generation_config
        )  # pytype: disable=attribute-error
        generation_config["response_schema"] = schema
        return generation_config
    elif isinstance(generation_config, GenerationConfig):
        generation_config = dataclasses.asdict(generation_config)
        _normalize_schema(generation_config)
        # Drop unset (None) fields.
        return {key: value for key, value in generation_config.items() if value is not None}
    elif hasattr(generation_config, "keys"):
        generation_config = dict(generation_config)
        _normalize_schema(generation_config)
        return generation_config
    else:
        raise TypeError(
            "Invalid input type. Expected a `dict` or `GenerationConfig` for `generation_config`.\n"
            f"However, received an object of type: {type(generation_config)}.\n"
            f"Object Value: {generation_config}"
        )
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _join_citation_metadatas(
|
| 232 |
+
citation_metadatas: Iterable[protos.CitationMetadata],
|
| 233 |
+
):
|
| 234 |
+
citation_metadatas = list(citation_metadatas)
|
| 235 |
+
return citation_metadatas[-1]
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def _join_safety_ratings_lists(
    safety_ratings_lists: Iterable[list[protos.SafetyRating]],
):
    """Merges per-chunk safety-rating lists.

    For each category, keeps the probability from the last chunk that reported
    it, and marks the category as blocked if any chunk marked it blocked.
    """
    latest_probability = {}
    ever_blocked = collections.defaultdict(bool)

    for ratings_list in safety_ratings_lists:
        for rating in ratings_list:
            latest_probability[rating.category] = rating.probability
            ever_blocked[rating.category] |= bool(rating.blocked)

    return [
        protos.SafetyRating(
            category=category,
            probability=probability,
            blocked=ever_blocked[category],
        )
        for category, probability in latest_probability.items()
    ]
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def _join_contents(contents: Iterable[protos.Content]):
    """Merges stream chunks of one `Content`, concatenating adjacent mergeable parts."""
    contents = tuple(contents)
    # Use the first non-empty role; later chunks may leave it unset.
    roles = [c.role for c in contents if c.role]
    if roles:
        role = roles[0]
    else:
        role = ""

    parts = []
    for content in contents:
        parts.extend(content.parts)

    # NOTE(review): assumes at least one part is present; an all-empty stream
    # would raise IndexError here — confirm callers guarantee this.
    merged_parts = []
    last = parts[0]
    for part in parts[1:]:
        # Adjacent text parts are concatenated into a single part.
        if "text" in last and "text" in part:
            last = protos.Part(text=last.text + part.text)
            continue

        # Can we merge the new thing into last?
        # If not, put last in list of parts, and new thing becomes last
        if "executable_code" in last and "executable_code" in part:
            last = protos.Part(
                executable_code=_join_executable_code(last.executable_code, part.executable_code)
            )
            continue

        if "code_execution_result" in last and "code_execution_result" in part:
            last = protos.Part(
                code_execution_result=_join_code_execution_result(
                    last.code_execution_result, part.code_execution_result
                )
            )
            continue

        merged_parts.append(last)
        last = part

    merged_parts.append(last)

    return protos.Content(
        role=role,
        parts=merged_parts,
    )
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def _join_executable_code(code_1, code_2):
    """Concatenates two streamed `ExecutableCode` chunks (keeps the first chunk's language)."""
    return protos.ExecutableCode(language=code_1.language, code=code_1.code + code_2.code)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def _join_code_execution_result(result_1, result_2):
    """Concatenates two streamed `CodeExecutionResult` chunks (keeps the last chunk's outcome)."""
    return protos.CodeExecutionResult(
        outcome=result_2.outcome, output=result_1.output + result_2.output
    )
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _join_candidates(candidates: Iterable[protos.Candidate]):
    """Joins stream chunks of a single candidate."""
    candidates = tuple(candidates)

    index = candidates[0].index  # These should all be the same.

    return protos.Candidate(
        index=index,
        content=_join_contents([c.content for c in candidates]),
        # The last chunk carries the final finish reason and token count.
        finish_reason=candidates[-1].finish_reason,
        safety_ratings=_join_safety_ratings_lists([c.safety_ratings for c in candidates]),
        citation_metadata=_join_citation_metadatas([c.citation_metadata for c in candidates]),
        token_count=candidates[-1].token_count,
    )
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]):
    """Joins stream chunks where each chunk is a list of candidate chunks."""
    # Assuming that if a candidate ends it is no longer returned in the list,
    # candidates carry an `index`; group the chunks of each candidate by it.
    grouped = collections.defaultdict(list)
    for chunk in candidate_lists:
        for candidate in chunk:
            grouped[candidate.index].append(candidate)

    # Re-join each candidate's chunks, ordered by index.
    return [_join_candidates(chunk_group) for _, chunk_group in sorted(grouped.items())]
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def _join_prompt_feedbacks(
|
| 349 |
+
prompt_feedbacks: Iterable[protos.GenerateContentResponse.PromptFeedback],
|
| 350 |
+
):
|
| 351 |
+
# Always return the first prompt feedback.
|
| 352 |
+
return next(iter(prompt_feedbacks))
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]):
    """Joins a sequence of streamed `GenerateContentResponse` chunks into one response."""
    chunks = tuple(chunks)
    # Usage metadata and model version are taken from the last chunk, when set.
    if "usage_metadata" in chunks[-1]:
        usage_metadata = chunks[-1].usage_metadata
    else:
        usage_metadata = None

    if "model_version" in chunks[-1]:
        model_version = chunks[-1].model_version
    else:
        model_version = None

    return protos.GenerateContentResponse(
        candidates=_join_candidate_lists(c.candidates for c in chunks),
        prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
        usage_metadata=usage_metadata,
        model_version=model_version,
    )
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
_INCOMPLETE_ITERATION_MESSAGE = """\
|
| 376 |
+
Please let the response complete iteration before accessing the final accumulated
|
| 377 |
+
attributes (or call `response.resolve()`)"""
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class BaseGenerateContentResponse:
    """Shared implementation for sync and async `generate_content` responses.

    Wraps a (possibly streaming) sequence of `protos.GenerateContentResponse`
    chunks and exposes quick accessors over the accumulated result.
    """

    def __init__(
        self,
        done: bool,
        iterator: (
            None
            | Iterable[protos.GenerateContentResponse]
            | AsyncIterable[protos.GenerateContentResponse]
        ),
        result: protos.GenerateContentResponse,
        chunks: Iterable[protos.GenerateContentResponse] | None = None,
    ):
        self._done = done
        self._iterator = iterator
        self._result = result
        if chunks is None:
            self._chunks = [result]
        else:
            self._chunks = list(chunks)
        if result.prompt_feedback.block_reason:
            # A blocked prompt yields no candidates; remember the error so
            # iteration can surface it to the caller.
            self._error = BlockedPromptException(result)
        else:
            self._error = None

    def to_dict(self):
        """Returns the result as a JSON-compatible dict.

        Note: This doesn't capture the iterator state when streaming, it only captures the accumulated
        `GenerateContentResponse` fields.

        >>> import json
        >>> response = model.generate_content('Hello?')
        >>> json.dumps(response.to_dict())
        """
        return type(self._result).to_dict(self._result)

    @property
    def candidates(self):
        """The list of candidate responses.

        Raises:
            IncompleteIterationError: With `stream=True` if iteration over the stream was not completed.
        """
        if not self._done:
            raise IncompleteIterationError(_INCOMPLETE_ITERATION_MESSAGE)
        return self._result.candidates

    @property
    def parts(self):
        """A quick accessor equivalent to `self.candidates[0].content.parts`

        Raises:
            ValueError: If the candidate list does not contain exactly one candidate.
        """
        candidates = self.candidates
        if not candidates:
            # Fix: removed duplicated word ("but but") in the error message.
            msg = (
                "Invalid operation: The `response.parts` quick accessor requires a single candidate, "
                "but `response.candidates` is empty."
            )
            if self.prompt_feedback:
                raise ValueError(
                    msg + "\nThis appears to be caused by a blocked prompt, "
                    f"see `response.prompt_feedback`: {self.prompt_feedback}"
                )
            else:
                raise ValueError(msg)

        if len(candidates) > 1:
            raise ValueError(
                "Invalid operation: The `response.parts` quick accessor retrieves the parts for a single candidate. "
                "This response contains multiple candidates, please use `result.candidates[index].text`."
            )
        parts = candidates[0].content.parts
        return parts

    @property
    def text(self):
        """A quick accessor equivalent to `self.candidates[0].content.parts[0].text`

        Raises:
            ValueError: If the candidate list or parts list does not contain exactly one entry.
        """
        parts = self.parts
        if not parts:
            # No parts: explain why via the candidate's finish_reason.
            candidate = self.candidates[0]

            fr = candidate.finish_reason
            FinishReason = protos.Candidate.FinishReason

            msg = (
                "Invalid operation: The `response.text` quick accessor requires the response to contain a valid "
                "`Part`, but none were returned. The candidate's "
                f"[finish_reason](https://ai.google.dev/api/generate-content#finishreason) is {fr}."
            )

            if fr is FinishReason.FINISH_REASON_UNSPECIFIED:
                raise ValueError(msg)
            elif fr is FinishReason.STOP:
                raise ValueError(msg)
            elif fr is FinishReason.MAX_TOKENS:
                raise ValueError(msg)
            elif fr is FinishReason.SAFETY:
                raise ValueError(
                    msg + f" The candidate's safety_ratings are: {candidate.safety_ratings}.",
                    candidate.safety_ratings,
                )
            elif fr is FinishReason.RECITATION:
                raise ValueError(
                    msg + " Meaning that the model was reciting from copyrighted material."
                )
            elif fr is FinishReason.LANGUAGE:
                raise ValueError(msg + " Meaning the response was using an unsupported language.")
            elif fr is FinishReason.OTHER:
                raise ValueError(msg)
            elif fr is FinishReason.BLOCKLIST:
                raise ValueError(msg)
            elif fr is FinishReason.PROHIBITED_CONTENT:
                raise ValueError(msg)
            elif fr is FinishReason.SPII:
                raise ValueError(msg + " SPII - Sensitive Personally Identifiable Information.")
            elif fr is FinishReason.MALFORMED_FUNCTION_CALL:
                raise ValueError(
                    msg + " Meaning that model generated a `FunctionCall` that was invalid. "
                    "Setting the "
                    "[Function calling mode](https://ai.google.dev/gemini-api/docs/function-calling#function_calling_mode) "
                    "to `ANY` can fix this because it enables constrained decoding."
                )
            else:
                raise ValueError(msg)

        texts = []
        for part in parts:
            if "text" in part:
                texts.append(part.text)
                continue
            if "executable_code" in part:
                # Render generated code as a fenced block tagged with its language.
                language = part.executable_code.language.name.lower()
                if language == "language_unspecified":
                    language = ""
                else:
                    language = f" {language}"
                texts.extend([f"```{language}", part.executable_code.code.lstrip("\n"), "```"])
                continue
            if "code_execution_result" in part:
                outcome_result = part.code_execution_result.outcome.name.lower().replace(
                    "outcome_", ""
                )
                if outcome_result == "ok" or outcome_result == "unspecified":
                    outcome_result = ""
                else:
                    outcome_result = f" {outcome_result}"
                texts.extend([f"```{outcome_result}", part.code_execution_result.output, "```"])
                continue

            # Fix: the protobuf Message API spells this method `WhichOneof`;
            # the previous `whichOneof` raised AttributeError before this
            # ValueError could even be constructed.
            part_type = protos.Part.pb(part).WhichOneof("data")
            raise ValueError(f"Could not convert `part.{part_type}` to text.")

        return "\n".join(texts)

    @property
    def prompt_feedback(self):
        # Available immediately, even while streaming.
        return self._result.prompt_feedback

    @property
    def usage_metadata(self):
        return self._result.usage_metadata

    @property
    def model_version(self):
        return self._result.model_version

    def __str__(self) -> str:
        if self._done:
            _iterator = "None"
        else:
            _iterator = f"<{self._iterator.__class__.__name__}>"

        as_dict = type(self._result).to_dict(
            self._result, use_integers_for_enums=False, including_default_value_fields=False
        )
        json_str = json.dumps(as_dict, indent=2)

        _result = f"protos.GenerateContentResponse({json_str})"
        # Indent continuation lines so the nested proto aligns inside the repr.
        _result = _result.replace("\n", "\n                    ")

        if self._error:
            _error = f",\nerror={repr(self._error)}"
        else:
            _error = ""

        return (
            textwrap.dedent(
                f"""\
                response:
                {type(self).__name__}(
                    done={self._done},
                    iterator={_iterator},
                    result={_result},
                )"""
            )
            + _error
        )

    __repr__ = __str__
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
@contextlib.contextmanager
def rewrite_stream_error():
    """Context manager converting low-level streaming parse failures into a
    `BadRequest` that suggests retrying with `stream=False`."""
    try:
        yield
    except (google.protobuf.json_format.ParseError, AttributeError) as e:
        # Fix: chain the original exception (`from e`) so the root cause stays
        # visible in the traceback instead of being silently discarded.
        raise google.api_core.exceptions.BadRequest(
            "Unknown error trying to retrieve streaming response. "
            "Please retry with `stream=False` for more details."
        ) from e
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
# Docstring applied to `GenerateContentResponse` via `string_utils.set_doc`.
# Fix: removed the stray trailing colon after `stream=True)` in the example.
GENERATE_CONTENT_RESPONSE_DOC = """Instances of this class manage the response of the `generate_content` method.

These are returned by `GenerativeModel.generate_content` and `ChatSession.send_message`.
This object is based on the low level `protos.GenerateContentResponse` class which just has `prompt_feedback`
and `candidates` attributes. This class adds several quick accessors for common use cases.

The same object type is returned for both `stream=True/False`.

### Streaming

When you pass `stream=True` to `GenerativeModel.generate_content` or `ChatSession.send_message`,
iterate over this object to receive chunks of the response:

```
response = model.generate_content(..., stream=True)
for chunk in response:
    print(chunk.text)
```

`GenerateContentResponse.prompt_feedback` is available immediately but
`GenerateContentResponse.candidates`, and all the attributes derived from them (`.text`, `.parts`),
are only available after the iteration is complete.
"""

# Docstring applied to `AsyncGenerateContentResponse`.
ASYNC_GENERATE_CONTENT_RESPONSE_DOC = (
    """This is the async version of `genai.GenerateContentResponse`."""
)
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
@string_utils.set_doc(GENERATE_CONTENT_RESPONSE_DOC)
class GenerateContentResponse(BaseGenerateContentResponse):
    @classmethod
    def from_iterator(cls, iterator: Iterable[protos.GenerateContentResponse]):
        """Wraps a streaming iterator, eagerly pulling the first chunk so that
        `prompt_feedback` is available before the caller starts iterating."""
        iterator = iter(iterator)
        with rewrite_stream_error():
            response = next(iterator)

        return cls(
            done=False,
            iterator=iterator,
            result=response,
        )

    @classmethod
    def from_response(cls, response: protos.GenerateContentResponse):
        """Wraps a complete (non-streaming) response; nothing left to iterate."""
        return cls(
            done=True,
            iterator=None,
            result=response,
        )

    def __iter__(self):
        # This is not thread safe.
        if self._done:
            # Stream already drained: replay the cached chunks.
            for chunk in self._chunks:
                yield GenerateContentResponse.from_response(chunk)
            return

        # Always have the next chunk available.
        if len(self._chunks) == 0:
            self._chunks.append(next(self._iterator))

        for n in itertools.count():
            # A failure recorded on a previous pass is re-raised here.
            if self._error:
                raise self._error

            if n >= len(self._chunks) - 1:
                # Look ahead for a new item, so that you know the stream is done
                # when you yield the last item.
                if self._done:
                    return

                try:
                    item = next(self._iterator)
                except StopIteration:
                    self._done = True
                except Exception as e:
                    # Remember the failure; it is raised at the top of the next
                    # loop pass (and on any later iteration attempt).
                    self._error = e
                    self._done = True
                else:
                    self._chunks.append(item)
                    # Merge the new chunk into the accumulated result.
                    self._result = _join_chunks([self._result, item])

            item = self._chunks[n]

            item = GenerateContentResponse.from_response(item)
            yield item

    def resolve(self):
        # Drain the stream so the accumulated attributes become available.
        if self._done:
            return

        for _ in self:
            pass
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
@string_utils.set_doc(ASYNC_GENERATE_CONTENT_RESPONSE_DOC)
class AsyncGenerateContentResponse(BaseGenerateContentResponse):
    @classmethod
    async def from_aiterator(cls, iterator: AsyncIterable[protos.GenerateContentResponse]):
        """Wraps an async streaming iterator, eagerly awaiting the first chunk
        so `prompt_feedback` is available before iteration starts."""
        iterator = aiter(iterator)  # type: ignore
        with rewrite_stream_error():
            response = await anext(iterator)  # type: ignore

        return cls(
            done=False,
            iterator=iterator,
            result=response,
        )

    @classmethod
    def from_response(cls, response: protos.GenerateContentResponse):
        """Wraps a complete (non-streaming) response; nothing left to iterate."""
        return cls(
            done=True,
            iterator=None,
            result=response,
        )

    async def __aiter__(self):
        # This is not thread safe.
        if self._done:
            # Stream already drained: replay the cached chunks.
            # NOTE(review): chunks are wrapped in the synchronous
            # `GenerateContentResponse`, not this async class — confirm intended.
            for chunk in self._chunks:
                yield GenerateContentResponse.from_response(chunk)
            return

        # Always have the next chunk available.
        if len(self._chunks) == 0:
            self._chunks.append(await anext(self._iterator))  # type: ignore

        for n in itertools.count():
            # A failure recorded on a previous pass is re-raised here.
            if self._error:
                raise self._error

            if n >= len(self._chunks) - 1:
                # Look ahead for a new item, so that you know the stream is done
                # when you yield the last item.
                if self._done:
                    return

                try:
                    item = await anext(self._iterator)  # type: ignore
                except StopAsyncIteration:
                    self._done = True
                except Exception as e:
                    # Remember the failure; it is raised on the next loop pass.
                    self._error = e
                    self._done = True
                else:
                    self._chunks.append(item)
                    # Merge the new chunk into the accumulated result.
                    self._result = _join_chunks([self._result, item])

            item = self._chunks[n]

            item = GenerateContentResponse.from_response(item)
            yield item

    async def resolve(self):
        # Drain the stream so the accumulated attributes become available.
        if self._done:
            return

        async for _ in self:
            pass
|
.venv/lib/python3.11/site-packages/google/generativeai/types/image_types/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (285 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/types/image_types/__pycache__/_image_types.cpython-311.pyc
ADDED
|
Binary file (20.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/generativeai/types/image_types/_image_types.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import io
|
| 5 |
+
import json
|
| 6 |
+
import mimetypes
|
| 7 |
+
import os
|
| 8 |
+
import pathlib
|
| 9 |
+
import typing
|
| 10 |
+
from typing import Any, Dict, Optional, Union
|
| 11 |
+
|
| 12 |
+
import httplib2
|
| 13 |
+
import threading
|
| 14 |
+
|
| 15 |
+
import googleapiclient.http
|
| 16 |
+
|
| 17 |
+
from google.generativeai import protos
|
| 18 |
+
|
| 19 |
+
# Optional-dependency scaffolding: PIL and IPython are imported if available;
# when an import fails the module name is bound to `None` so callers can guard
# with `if PIL is not None:` etc.
# pylint: disable=g-import-not-at-top
if typing.TYPE_CHECKING:
    import PIL.Image
    import PIL.ImageFile
    import IPython.display

    _IMPORTED_IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
    ImageType = Union["Image", PIL.Image.Image, IPython.display.Image]
else:
    _IMPORTED_IMAGE_TYPES = ()
    try:
        import PIL.Image
        import PIL.ImageFile

        _IMPORTED_IMAGE_TYPES = _IMPORTED_IMAGE_TYPES + (PIL.Image.Image,)
    except ImportError:
        PIL = None

    try:
        import IPython.display

        # NOTE(review): this assigns `IMAGE_TYPES` rather than extending
        # `_IMPORTED_IMAGE_TYPES`, and `IMAGE_TYPES` is rebuilt later from
        # `_IMPORTED_IMAGE_TYPES` — so `IPython.display.Image` ends up excluded
        # from the final tuple. Looks unintentional; confirm against upstream.
        IMAGE_TYPES = _IMPORTED_IMAGE_TYPES + (IPython.display.Image,)
    except ImportError:
        IPython = None

    ImageType = Union["Image", "PIL.Image.Image", "IPython.display.Image"]
# pylint: enable=g-import-not-at-top
|
| 46 |
+
|
| 47 |
+
# Public API of this module.
# NOTE(review): `check_watermark` is exported here but no definition of it
# appears anywhere in this file — `from ... import *` would fail; verify.
__all__ = [
    "Image",
    "GeneratedImage",
    "to_image",
    "check_watermark",
    "CheckWatermarkResult",
    "ImageType",
    "Video",
    "GeneratedVideo",
]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
    """Convert a PIL image into a `protos.Blob`.

    Images that are backed by a readable local file are sent verbatim with
    their original MIME type; anything else is re-encoded as lossless WebP
    (identical quality, typically a smaller payload).
    """

    def _blob_from_source_file(pil_image: PIL.Image.Image) -> Union[protos.Blob, None]:
        # Only images opened from disk (ImageFile subclasses) carry a filename.
        if not isinstance(pil_image, PIL.ImageFile.ImageFile) or pil_image.filename is None:
            return None
        source = pathlib.Path(str(pil_image.filename))
        if not source.is_file():
            return None

        return protos.Blob(
            mime_type=pil_image.get_format_mimetype(),
            data=source.read_bytes(),
        )

    def _blob_as_lossless_webp(pil_image: PIL.Image.Image) -> protos.Blob:
        # Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
        buffer = io.BytesIO()
        pil_image.save(buffer, format="webp", lossless=True)

        return protos.Blob(mime_type="image/webp", data=buffer.getvalue())

    return _blob_from_source_file(image) or _blob_as_lossless_webp(image)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def image_to_blob(image: ImageType) -> protos.Blob:
    """Return a `protos.Blob` holding the bytes and MIME type of *image*.

    Accepts this module's `Image`, a `PIL.Image.Image`, or a file-backed
    `IPython.display.Image`.

    Raises:
        ValueError: For an `IPython.display.Image` without a source file.
        TypeError: For any unsupported input type.
    """
    if PIL is not None and isinstance(image, PIL.Image.Image):
        return _pil_to_blob(image)

    if IPython is not None and isinstance(image, IPython.display.Image):
        filename = image.filename
        if filename is None:
            raise ValueError(
                "Conversion failed. The `IPython.display.Image` can only be converted if "
                "it is constructed from a local file. Please ensure you are using the format: Image(filename='...')."
            )
        guessed_mime, _ = mimetypes.guess_type(filename)
        if guessed_mime is None:
            guessed_mime = "image/unknown"

        return protos.Blob(mime_type=guessed_mime, data=image.data)

    if isinstance(image, Image):
        return protos.Blob(mime_type=image._mime_type, data=image._image_bytes)

    raise TypeError(
        "Image conversion failed. The input was expected to be of type `Image` "
        "(either `PIL.Image.Image` or `IPython.display.Image`).\n"
        f"However, received an object of type: {type(image)}.\n"
        f"Object Value: {image}"
    )
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class CheckWatermarkResult:
    """Wraps the watermark-check predictions returned by the service.

    Truthiness reflects the verdict: ``ACCEPT`` is True, ``REJECT`` is False;
    any other verdict raises `ValueError`.
    """

    def __init__(self, predictions):
        self._predictions = predictions

    @property
    def decision(self):
        # The service returns a single prediction; expose its verdict string.
        return self._predictions[0]["decision"]

    def __str__(self):
        return f"CheckWatermarkResult([{{'decision': {self.decision!r}}}])"

    def __bool__(self):
        verdict = self.decision
        if verdict == "ACCEPT":
            return True
        if verdict == "REJECT":
            return False
        raise ValueError("Unrecognized result")
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def to_image(img: Union[pathlib.Path, ImageType]) -> Image:
    """Coerce any supported image representation into an `Image`.

    Args:
        img: An `Image`, a `pathlib.Path` to an image file, an
            `IPython.display.Image`, a `PIL.Image.Image`, or a `protos.Blob`.

    Returns:
        The equivalent `Image` instance.

    Raises:
        TypeError: If `img` is not one of the supported types.
    """
    if isinstance(img, Image):
        pass
    elif isinstance(img, pathlib.Path):
        img = Image.load_from_file(img)
    # Fix: guard on the module object itself. When an optional dependency
    # failed to import, the module name is bound to `None`, so the previous
    # `IPython.display is not None` / `PIL.Image is not None` checks raised
    # AttributeError instead of falling through.
    elif IPython is not None and isinstance(img, IPython.display.Image):
        img = Image(image_bytes=img.data)
    elif PIL is not None and isinstance(img, PIL.Image.Image):
        blob = _pil_to_blob(img)
        img = Image(image_bytes=blob.data)
    elif isinstance(img, protos.Blob):
        img = Image(image_bytes=img.data)
    else:
        raise TypeError(
            f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}"
        )

    return img
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class Image:
    """Image."""

    __module__ = "vertexai.vision_models"

    # Lazily-populated caches: the raw encoded bytes, and the decoded PIL image.
    _loaded_bytes: Optional[bytes] = None
    _loaded_image: Optional["PIL.Image.Image"] = None

    def __init__(
        self,
        image_bytes: Optional[bytes],
    ):
        """Creates an `Image` object.

        Args:
            image_bytes: Image file bytes. Image can be in PNG or JPEG format.
        """
        self._image_bytes = image_bytes

    @staticmethod
    def load_from_file(location: os.PathLike) -> "Image":
        """Loads image from local file.

        Args:
            location: Local path from where to load
                the image.

        Returns:
            Loaded image as an `Image` object.
        """
        # Load image from local path
        image_bytes = pathlib.Path(location).read_bytes()
        image = Image(image_bytes=image_bytes)
        return image

    @property
    def _image_bytes(self) -> bytes:
        # Property indirection lets subclasses (and lazy loaders) override it.
        return self._loaded_bytes

    @_image_bytes.setter
    def _image_bytes(self, value: bytes):
        self._loaded_bytes = value

    @property
    def _pil_image(self) -> "PIL.Image.Image":  # type: ignore
        # Decode on first access and cache; requires the optional Pillow package.
        if self._loaded_image is None:
            if not PIL:
                raise RuntimeError(
                    "The PIL module is not available. Please install the Pillow package."
                )
            self._loaded_image = PIL.Image.open(io.BytesIO(self._image_bytes))
        return self._loaded_image

    @property
    def _size(self):
        # (width, height) of the decoded image.
        return self._pil_image.size

    @property
    def _mime_type(self) -> str:
        """Returns the MIME type of the image."""
        import PIL

        # Falls back to JPEG when PIL cannot name the decoded format.
        return PIL.Image.MIME.get(self._pil_image.format, "image/jpeg")

    def show(self):
        """Shows the image.

        This method only works when in a notebook environment.
        """
        # Silently does nothing outside a notebook / without Pillow.
        if PIL and IPython:
            IPython.display.display(self._pil_image)

    def save(self, location: str):
        """Saves image to a file.

        Args:
            location: Local path where to save the image.
        """
        pathlib.Path(location).write_bytes(self._image_bytes)

    def _as_base64_string(self) -> str:
        """Encodes image using the base64 encoding.

        Returns:
            Base64 encoding of the image as a string.
        """
        # ! b64encode returns `bytes` object, not `str`.
        # We need to convert `bytes` to `str`, otherwise we get service error:
        # "received initial metadata size exceeds limit"
        return base64.b64encode(self._image_bytes).decode("ascii")

    def _repr_png_(self):
        # Lets Jupyter render the image inline.
        return self._pil_image._repr_png_()  # type:ignore
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# Final tuple of accepted image classes: the optional third-party types that
# were importable plus this module's own `Image`.
IMAGE_TYPES = _IMPORTED_IMAGE_TYPES + (Image,)


# EXIF tag id of the "UserComment" field, where generation parameters are
# serialized as JSON when saving a `GeneratedImage`.
_EXIF_USER_COMMENT_TAG_IDX = 0x9286
# JSON key under which the generation parameters are stored in that comment.
_IMAGE_GENERATION_PARAMETERS_EXIF_KEY = (
    "google.cloud.vertexai.image_generation.image_generation_parameters"
)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class GeneratedImage(Image):
    """Generated image."""

    __module__ = "google.generativeai"

    def __init__(
        self,
        image_bytes: Optional[bytes],
        generation_parameters: Dict[str, Any],
    ):
        """Creates a `GeneratedImage` object.

        Args:
            image_bytes: Image file bytes. Image can be in PNG or JPEG format.
            generation_parameters: Image generation parameter values.
        """
        super().__init__(image_bytes=image_bytes)
        self._generation_parameters = generation_parameters

    @property
    def generation_parameters(self):
        """Image generation parameters as a dictionary."""
        return self._generation_parameters

    @staticmethod
    def load_from_file(location: os.PathLike) -> "GeneratedImage":
        """Loads image from file.

        Args:
            location: Local path from where to load the image.

        Returns:
            Loaded image as a `GeneratedImage` object.
        """
        base_image = Image.load_from_file(location=location)
        # Generation parameters are stored as JSON in the EXIF UserComment tag;
        # a KeyError propagates if the file was not saved with them.
        exif = base_image._pil_image.getexif()  # pylint: disable=protected-access
        exif_comment_dict = json.loads(exif[_EXIF_USER_COMMENT_TAG_IDX])
        generation_parameters = exif_comment_dict[_IMAGE_GENERATION_PARAMETERS_EXIF_KEY]
        return GeneratedImage(
            image_bytes=base_image._image_bytes,  # pylint: disable=protected-access
            generation_parameters=generation_parameters,
        )

    def save(self, location: str, include_generation_parameters: bool = True):
        """Saves image to a file.

        Args:
            location: Local path where to save the image.
            include_generation_parameters: Whether to include the image
                generation parameters in the image's EXIF metadata.
        """
        if include_generation_parameters:
            if not self._generation_parameters:
                raise ValueError("Image does not have generation parameters.")
            if not PIL:
                raise ValueError("The PIL module is required for saving generation parameters.")

            # Embed the parameters as JSON in the EXIF UserComment tag.
            # NOTE(review): this path re-encodes via PIL (format inferred from
            # the file extension) rather than writing the original bytes.
            exif = self._pil_image.getexif()
            exif[_EXIF_USER_COMMENT_TAG_IDX] = json.dumps(
                {_IMAGE_GENERATION_PARAMETERS_EXIF_KEY: self._generation_parameters}
            )
            self._pil_image.save(location, exif=exif)
        else:
            super().save(location=location)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class Video:
    """Container for raw video bytes plus their MIME type."""

    # Backing storage for the `_video_bytes` property.
    _loaded_bytes: Optional[bytes] = None

    def __init__(
        self,
        *,
        video_bytes: Optional[bytes] = None,
        mime_type: Optional[str] = None,
    ):
        """Creates a `Video` object.

        Args:
            video_bytes: Video file bytes. Video can be in AVI, FLV, MKV, MOV,
                MP4, MPEG, MPG, WEBM, and WMV formats.
            mime_type: MIME type of the video, e.g. "video/mp4".
        """
        self._video_bytes = video_bytes
        self._mime_type = mime_type

    def _ipython_display_(self):
        """Renders the video inline when displayed in a notebook."""
        # Fix: guard on the module object itself — when IPython failed to
        # import it is bound to `None`, so `IPython.display` raised
        # AttributeError instead of returning quietly.
        if IPython is None:
            return

        IPython.display.display(
            IPython.display.Video(data=self._video_bytes, mimetype=self._mime_type, embed=True)
        )

    @staticmethod
    def load_from_file(location: str) -> "Video":
        """Loads video from local file.

        Args:
            location: Local path from where to load the video.

        Returns:
            Loaded video as a `Video` object.
        """
        video_bytes = pathlib.Path(location).read_bytes()
        # Fix: keep only the type component — `guess_type` returns a
        # `(type, encoding)` tuple, and the tuple itself was previously stored
        # as the MIME type (the duplicate discarded call is also removed).
        mime_type, _ = mimetypes.guess_type(location)
        return Video(video_bytes=video_bytes, mime_type=mime_type)

    @property
    def _video_bytes(self) -> bytes:
        # Property indirection lets subclasses fetch bytes lazily.
        return self._loaded_bytes

    @_video_bytes.setter
    def _video_bytes(self, value: bytes):
        self._loaded_bytes = value

    @property
    def mime_type(self) -> str:
        """Returns the MIME type of the video."""
        return self._mime_type

    def save(self, location: str):
        """Saves video to a file.

        Args:
            location: Local path where to save the video.
        """
        pathlib.Path(location).write_bytes(self._video_bytes)

    def _as_base64_string(self) -> str:
        """Encodes video using the base64 encoding.

        Returns:
            Base64 encoding of the video as a string.
        """
        # ! b64encode returns `bytes` object, not `str`.
        # We need to convert `bytes` to `str`, otherwise we get service error:
        # "received initial metadata size exceeds limit"
        return base64.b64encode(self._video_bytes).decode("ascii")
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
# HTTP client reused by `GeneratedVideo.download`.
# NOTE(review): `http` is only assigned on the thread that imports this
# module; a different thread reading `_thread_local.http` would get
# AttributeError — confirm whether cross-thread downloads are expected.
_thread_local = threading.local()
_thread_local.http = httplib2.Http()
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class GeneratedVideo(Video):
    """A `Video` produced by the service, fetched lazily from a download URI."""

    def __init__(self, *, uri):
        # Intentionally skips `Video.__init__`: bytes are downloaded on first
        # access through the `_video_bytes` property below.
        self._uri = uri
        self._loaded_bytes = None

    @property
    def _mime_type(self):
        # Generated videos are always served as MP4.
        return "video/mp4"

    @property
    def _video_bytes(self) -> bytes:
        # Lazy download-and-cache of the video content.
        if self._loaded_bytes is None:
            self._loaded_bytes = self.download()
        return self._loaded_bytes

    @_video_bytes.setter
    def _video_bytes(self, value: bytes):
        self._loaded_bytes = value

    def download(self):
        """Fetches and returns the video bytes from `self._uri`, authenticating
        with the configured API key as a query parameter."""
        # NOTE(review): `client` is not imported anywhere in this module as
        # shown — presumably `google.generativeai.client`; verify the import.
        api_key = client._client_manager.client_config["client_options"].api_key

        # NOTE(review): `googleapiclient.model` is referenced but only
        # `googleapiclient.http` is imported above — confirm it resolves.
        request = googleapiclient.http.HttpRequest(
            http=_thread_local.http,
            postproc=googleapiclient.model.MediaModel().response,
            uri=f"{self._uri}&key={api_key}",
            method="GET",
            headers={
                "accept": "*/*",
                "accept-encoding": "gzip, deflate",
                "user-agent": "(gzip)",
                "x-goog-api-client": "gdcl/2.151.0 gl-python/3.12.7",
            },
        )
        result = request.execute()
        return result
|
.venv/lib/python3.11/site-packages/google/generativeai/types/model_types.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Type definitions for the models service."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from collections.abc import Mapping
|
| 19 |
+
import csv
|
| 20 |
+
import dataclasses
|
| 21 |
+
import datetime
|
| 22 |
+
import json
|
| 23 |
+
import pathlib
|
| 24 |
+
import re
|
| 25 |
+
|
| 26 |
+
from typing import Any, Iterable, Union
|
| 27 |
+
|
| 28 |
+
import urllib.request
|
| 29 |
+
from typing_extensions import TypedDict
|
| 30 |
+
|
| 31 |
+
from google.generativeai import protos
|
| 32 |
+
from google.generativeai.types import permission_types
|
| 33 |
+
from google.generativeai import string_utils
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
__all__ = [
|
| 37 |
+
"Model",
|
| 38 |
+
"ModelNameOptions",
|
| 39 |
+
"AnyModelNameOptions",
|
| 40 |
+
"BaseModelNameOptions",
|
| 41 |
+
"TunedModelNameOptions",
|
| 42 |
+
"ModelsIterable",
|
| 43 |
+
"TunedModel",
|
| 44 |
+
"TunedModelState",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
TunedModelState = protos.TunedModel.State
|
| 48 |
+
|
| 49 |
+
TunedModelStateOptions = Union[None, str, int, TunedModelState]
|
| 50 |
+
|
| 51 |
+
_TUNED_MODEL_VALID_NAME = r"[a-z](([a-z0-9-]{0,61}[a-z0-9])?)$"
TUNED_MODEL_NAME_ERROR_MSG = """The `name` must consist of alphanumeric characters (or -) and be at most 63 characters; The name you entered:
\tlen(name)== {length}
\tname={name}
"""


def valid_tuned_model_name(name: str) -> bool:
    """Return True if `name` is a valid tuned-model ID.

    Valid names start with a lowercase letter, contain only lowercase
    letters, digits and hyphens, do not end with a hyphen, and are at
    most 63 characters long.
    """
    # `re.fullmatch` (not `re.match`) so a trailing newline is rejected:
    # with `match`, the `$` anchor also matches just before "\n".
    return re.fullmatch(_TUNED_MODEL_VALID_NAME, name) is not None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# fmt: off
|
| 63 |
+
_TUNED_MODEL_STATES: dict[TunedModelStateOptions, TunedModelState] = {
|
| 64 |
+
TunedModelState.ACTIVE: TunedModelState.ACTIVE,
|
| 65 |
+
int(TunedModelState.ACTIVE): TunedModelState.ACTIVE,
|
| 66 |
+
"active": TunedModelState.ACTIVE,
|
| 67 |
+
|
| 68 |
+
TunedModelState.CREATING: TunedModelState.CREATING,
|
| 69 |
+
int(TunedModelState.CREATING): TunedModelState.CREATING,
|
| 70 |
+
"creating": TunedModelState.CREATING,
|
| 71 |
+
|
| 72 |
+
TunedModelState.FAILED: TunedModelState.FAILED,
|
| 73 |
+
int(TunedModelState.FAILED): TunedModelState.FAILED,
|
| 74 |
+
"failed": TunedModelState.FAILED,
|
| 75 |
+
|
| 76 |
+
TunedModelState.STATE_UNSPECIFIED: TunedModelState.STATE_UNSPECIFIED,
|
| 77 |
+
int(TunedModelState.STATE_UNSPECIFIED): TunedModelState.STATE_UNSPECIFIED,
|
| 78 |
+
"state_unspecified": TunedModelState.STATE_UNSPECIFIED,
|
| 79 |
+
"unspecified": TunedModelState.STATE_UNSPECIFIED,
|
| 80 |
+
None: TunedModelState.STATE_UNSPECIFIED,
|
| 81 |
+
}
|
| 82 |
+
# fmt: on
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def to_tuned_model_state(x: TunedModelStateOptions) -> TunedModelState:
    """Coerce a string/int/enum/None state spec into a `TunedModelState`."""
    key = x.lower() if isinstance(x, str) else x
    return _TUNED_MODEL_STATES[key]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@string_utils.prettyprint
@dataclasses.dataclass
class Model:
    """A dataclass representation of a `protos.Model`.

    Attributes:
        name: The resource name of the `Model`. Format: `models/{model}` with a `{model}` naming
            convention of: "{base_model_id}-{version}". For example: `models/chat-bison-001`.
        base_model_id: The base name of the model. For example: `chat-bison`.
        version: The major version number of the model. For example: `001`.
        display_name: The human-readable name of the model. E.g. `"Chat Bison"`. The name can be up
            to 128 characters long and can consist of any UTF-8 characters.
        description: A short description of the model.
        input_token_limit: Maximum number of input tokens allowed for this model.
        output_token_limit: Maximum number of output tokens available for this model.
        supported_generation_methods: lists which methods are supported by the model. The method
            names are defined as Pascal case strings, such as `generateMessage` which correspond to
            API methods.
        temperature: Sampling temperature reported for the model; `None` when not reported.
        max_temperature: Upper bound for the temperature; `None` when not reported.
        top_p: Nucleus-sampling value reported for the model; `None` when not reported.
        top_k: Top-k sampling value reported for the model; `None` when not reported.
    """

    name: str
    base_model_id: str
    version: str
    display_name: str
    description: str
    input_token_limit: int
    output_token_limit: int
    supported_generation_methods: list[str]
    temperature: float | None = None
    max_temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _fix_microseconds(match):
    """Normalize a fractional-seconds match (e.g. ".5") to exactly 6 digits.

    `%f` in `strptime` accepts at most 6 digits, so round the fraction to
    microseconds and zero-pad.
    """
    fraction = float(match.group(0))
    # Clamp: fractions >= .9999995 would otherwise round up to 1_000_000,
    # producing 7 digits and breaking the `%f` parse in `idecode_time`.
    micros = min(int(round(fraction * 1e6)), 999999)
    return f".{micros:06d}"


def idecode_time(parent: dict["str", Any], name: str):
    """In-place: replace the RFC-3339 string at `parent[name]` with an
    aware (UTC) `datetime`.

    The key is popped first; when missing or None it is not re-inserted.
    """
    time = parent.pop(name, None)
    if time is not None:
        if "." in time:
            # Normalize the fractional part to the 6 digits `%f` expects.
            time = re.sub(r"\.\d+", _fix_microseconds, time)
            dt = datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S.%fZ")
        else:
            dt = datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%SZ")

        # Service timestamps are UTC ("Z" suffix); attach the tzinfo explicitly.
        dt = dt.replace(tzinfo=datetime.timezone.utc)
        parent[name] = dt
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel:
    """Convert a `protos.TunedModel` (or its dict form) into a `TunedModel` dataclass."""
    if isinstance(tuned_model, protos.TunedModel):
        tuned_model = type(tuned_model).to_dict(
            tuned_model, including_default_value_fields=False
        )  # pytype: disable=attribute-error
    # Normalize the state to the TunedModelState enum (missing -> UNSPECIFIED).
    tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None))

    base_model = tuned_model.pop("base_model", None)
    tuned_model_source = tuned_model.pop("tuned_model_source", None)
    if base_model is not None:
        # Tuned directly from a base model: base and source are the same.
        tuned_model["base_model"] = base_model
        tuned_model["source_model"] = base_model
    elif tuned_model_source is not None:
        # Tuned from another tuned model: keep the root base model and
        # record the immediate (tuned) source separately.
        tuned_model["base_model"] = tuned_model_source["base_model"]
        tuned_model["source_model"] = tuned_model_source["tuned_model"]

    # Timestamps arrive as RFC-3339 strings; convert in place to aware datetimes.
    idecode_time(tuned_model, "create_time")
    idecode_time(tuned_model, "update_time")

    task = tuned_model.pop("tuning_task", None)
    if task is not None:
        hype = task.pop("hyperparameters", None)
        if hype is not None:
            hype = Hyperparameters(**hype)
            task["hyperparameters"] = hype

        idecode_time(task, "start_time")
        idecode_time(task, "complete_time")

        snapshots = task.pop("snapshots", None)
        if snapshots is not None:
            # Each snapshot also carries a string timestamp to decode.
            for snap in snapshots:
                idecode_time(snap, "compute_time")
            task["snapshots"] = snapshots
        task = TuningTask(**task)
        tuned_model["tuning_task"] = task
    return TunedModel(**tuned_model)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
@string_utils.prettyprint
@dataclasses.dataclass
class TunedModel:
    """A dataclass representation of a `protos.TunedModel`."""

    name: str | None = None
    # Immediate parent of the tuning run (may itself be a tuned model);
    # see `decode_tuned_model` for how these two fields are populated.
    source_model: str | None = None
    # Root base model the tuning chain started from.
    base_model: str | None = None
    display_name: str = ""
    description: str = ""
    temperature: float | None = None
    top_p: float | None = None
    top_k: float | None = None
    state: TunedModelState = TunedModelState.STATE_UNSPECIFIED
    create_time: datetime.datetime | None = None
    update_time: datetime.datetime | None = None
    tuning_task: TuningTask | None = None
    reader_project_numbers: list[int] | None = None

    @property
    def permissions(self) -> permission_types.Permissions:
        """Permissions object scoped to this tuned model."""
        return permission_types.Permissions(self)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@string_utils.prettyprint
@dataclasses.dataclass
class TuningTask:
    """Metadata for a tuning run: timing, metric snapshots, and hyperparameters."""

    start_time: datetime.datetime | None = None
    complete_time: datetime.datetime | None = None
    snapshots: list[TuningSnapshot] = dataclasses.field(default_factory=list)
    hyperparameters: Hyperparameters | None = None
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class TuningExampleDict(TypedDict):
    """Dict form of a single tuning example: model input and expected output."""

    text_input: str
    output: str
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
TuningExampleOptions = Union[TuningExampleDict, protos.TuningExample, tuple[str, str], list[str]]
|
| 222 |
+
|
| 223 |
+
# TODO(markdaoust): gs:// URLS? File-type argument for files without extension?
|
| 224 |
+
TuningDataOptions = Union[
|
| 225 |
+
pathlib.Path,
|
| 226 |
+
str,
|
| 227 |
+
protos.Dataset,
|
| 228 |
+
Mapping[str, Iterable[str]],
|
| 229 |
+
Iterable[TuningExampleOptions],
|
| 230 |
+
]
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def encode_tuning_data(
    data: TuningDataOptions, input_key="text_input", output_key="output"
) -> protos.Dataset:
    """Normalize any supported tuning-data source into a `protos.Dataset`.

    Accepts an existing `protos.Dataset`, a URL or filesystem path to a
    `.json`/`.csv` file, a column-oriented mapping, or an iterable of examples.
    """
    if isinstance(data, protos.Dataset):
        return data

    if isinstance(data, str):
        # Strings are either URLs or system paths.
        if re.match(r"^\w+://\S+$", data):
            data = _normalize_url(data)
        else:
            # Normalize system paths to use pathlib
            data = pathlib.Path(data)

    if isinstance(data, (str, pathlib.Path)):
        if isinstance(data, str):
            f = urllib.request.urlopen(data)
            # csv needs strings, json does not.
            content = (line.decode("utf-8") for line in f)
        else:
            f = data.open("r")
            content = f

        if str(data).lower().endswith(".json"):
            with f:
                data = json.load(f)
            # Fall through: the decoded JSON is a mapping or iterable of examples.
        else:
            with f:
                # csv.DictReader is lazy — it must be fully consumed while the
                # file is still open, so convert (and return) inside the `with`.
                return _convert_iterable(csv.DictReader(content), input_key, output_key)

    if hasattr(data, "keys"):
        return _convert_dict(data, input_key, output_key)
    else:
        return _convert_iterable(data, input_key, output_key)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _normalize_url(url: str) -> str:
    """Rewrite a Google Sheets URL into its CSV-export form.

    Non-Sheets URLs are returned unchanged. A `gid` tab selector, when
    present, is preserved in the export URL.

    Raises:
        ValueError: If the URL points at Sheets but lacks a `/d/<id>` segment.
    """
    sheet_base = "https://docs.google.com/spreadsheets"
    if url.startswith(sheet_base):
        # Normalize google-sheets URLs to download the csv.
        id_match = re.match(f"{sheet_base}/d/[^/]+", url)
        if id_match is None:
            # Bug fix: the original string was not an f-string and referenced
            # an undefined `data`, so it raised with a literal "{data}".
            raise ValueError(f"Incomplete Google Sheets URL: {url}")

        if tab_match := re.search(r"gid=(\d+)", url):
            tab_param = f"&gid={tab_match.group(1)}"
        else:
            tab_param = ""

        url = f"{id_match.group(0)}/export?format=csv{tab_param}"

    return url
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _convert_dict(data, input_key, output_key):
    """Build a `protos.Dataset` from a column-oriented mapping of examples."""
    columns = []
    for key, role in ((input_key, "input"), (output_key, "output")):
        try:
            columns.append(data[key])
        except KeyError:
            raise KeyError(
                f"Invalid key: The {role} key '{key}' does not exist in the data. "
                f"Available keys are: {sorted(data.keys())}."
            )

    # Pair the two columns row-by-row; zip stops at the shorter column.
    examples = [
        protos.TuningExample({"text_input": str(i), "output": str(o)})
        for i, o in zip(*columns)
    ]
    return protos.Dataset(examples=protos.TuningExamples(examples=examples))
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def _convert_iterable(data, input_key, output_key):
    """Encode every item of `data` into a `protos.Dataset` of tuning examples."""
    examples = [encode_tuning_example(item, input_key, output_key) for item in data]
    return protos.Dataset(examples=protos.TuningExamples(examples=examples))
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def encode_tuning_example(example: TuningExampleOptions, input_key, output_key):
    """Coerce one example (proto, pair, or mapping) into a `protos.TuningExample`."""
    if isinstance(example, protos.TuningExample):
        return example
    if isinstance(example, (tuple, list)):
        text_input, output = example
        return protos.TuningExample(text_input=text_input, output=output)
    # Mapping-style example: pull the two configured keys.
    return protos.TuningExample(text_input=example[input_key], output=example[output_key])
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
@string_utils.prettyprint
@dataclasses.dataclass
class TuningSnapshot:
    """A single metrics snapshot captured during a tuning run."""

    step: int
    epoch: int
    mean_score: float
    compute_time: datetime.datetime
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
@string_utils.prettyprint
@dataclasses.dataclass
class Hyperparameters:
    """Tuning hyperparameters.

    Zero defaults presumably mean "unset" (service-side defaults apply) —
    TODO confirm against the tuning service docs.
    """

    epoch_count: int = 0
    batch_size: int = 0
    learning_rate: float = 0.0
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
BaseModelNameOptions = Union[str, Model, protos.Model]
|
| 349 |
+
TunedModelNameOptions = Union[str, TunedModel, protos.TunedModel]
|
| 350 |
+
AnyModelNameOptions = Union[str, Model, protos.Model, TunedModel, protos.TunedModel]
|
| 351 |
+
ModelNameOptions = AnyModelNameOptions
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def make_model_name(name: AnyModelNameOptions):
    """Extract and validate the resource-name string from any model reference.

    Accepts a `Model`/`TunedModel` (dataclass or proto) or an already-formed
    name string.

    Raises:
        TypeError: For unsupported input types.
        ValueError: If the name lacks a `models/` or `tunedModels/` prefix.
    """
    if isinstance(name, (Model, protos.Model, TunedModel, protos.TunedModel)):
        name = name.name  # pytype: disable=attribute-error
    elif not isinstance(name, str):
        # (The original had a no-op `elif isinstance(name, str): name = name`.)
        raise TypeError(
            "Invalid input type. Expected one of the following types: `str`, `Model`, or `TunedModel`."
        )

    if not name.startswith(("models/", "tunedModels/")):
        raise ValueError(
            f"Invalid model name: '{name}'. Model names should start with 'models/' or 'tunedModels/'."
        )

    return name
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
ModelsIterable = Iterable[Model]
|
| 373 |
+
TunedModelsIterable = Iterable[TunedModel]
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
@string_utils.prettyprint
@dataclasses.dataclass
class TokenCount:
    """A dataclass representation of a `protos.TokenCountResponse`.

    Attributes:
        token_count: The number of tokens returned by the model's tokenizer for the `input_text`.
        token_count_limit: The maximum token count; `over_limit` compares `token_count` against it.
    """

    token_count: int
    token_count_limit: int

    def over_limit(self):
        """Return True when `token_count` exceeds `token_count_limit`."""
        return self.token_count > self.token_count_limit
|
.venv/lib/python3.11/site-packages/google/generativeai/types/palm_safety_types.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
from collections.abc import Mapping
|
| 18 |
+
|
| 19 |
+
import enum
|
| 20 |
+
import typing
|
| 21 |
+
from typing import Dict, Iterable, List, Union
|
| 22 |
+
|
| 23 |
+
from typing_extensions import TypedDict
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
from google.generativeai import protos
|
| 27 |
+
from google.generativeai import string_utils
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
__all__ = [
|
| 31 |
+
"HarmCategory",
|
| 32 |
+
"HarmProbability",
|
| 33 |
+
"HarmBlockThreshold",
|
| 34 |
+
"BlockedReason",
|
| 35 |
+
"ContentFilterDict",
|
| 36 |
+
"SafetyRatingDict",
|
| 37 |
+
"SafetySettingDict",
|
| 38 |
+
"SafetyFeedbackDict",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
# These are basic python enums, it's okay to expose them
|
| 42 |
+
HarmProbability = protos.SafetyRating.HarmProbability
|
| 43 |
+
HarmBlockThreshold = protos.SafetySetting.HarmBlockThreshold
|
| 44 |
+
BlockedReason = protos.ContentFilter.BlockedReason
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class HarmCategory:
    """
    Harm Categories supported by the palm-family models
    """

    # Plain-int mirrors of the corresponding `protos.HarmCategory` enum values.
    HARM_CATEGORY_UNSPECIFIED = protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value
    HARM_CATEGORY_DEROGATORY = protos.HarmCategory.HARM_CATEGORY_DEROGATORY.value
    HARM_CATEGORY_TOXICITY = protos.HarmCategory.HARM_CATEGORY_TOXICITY.value
    HARM_CATEGORY_VIOLENCE = protos.HarmCategory.HARM_CATEGORY_VIOLENCE.value
    HARM_CATEGORY_SEXUAL = protos.HarmCategory.HARM_CATEGORY_SEXUAL.value
    HARM_CATEGORY_MEDICAL = protos.HarmCategory.HARM_CATEGORY_MEDICAL.value
    HARM_CATEGORY_DANGEROUS = protos.HarmCategory.HARM_CATEGORY_DANGEROUS.value
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
HarmCategoryOptions = Union[str, int, HarmCategory]
|
| 62 |
+
|
| 63 |
+
# fmt: off
|
| 64 |
+
_HARM_CATEGORIES: Dict[HarmCategoryOptions, protos.HarmCategory] = {
|
| 65 |
+
protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
| 66 |
+
HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
| 67 |
+
0: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
| 68 |
+
"harm_category_unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
| 69 |
+
"unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
| 70 |
+
|
| 71 |
+
protos.HarmCategory.HARM_CATEGORY_DEROGATORY: protos.HarmCategory.HARM_CATEGORY_DEROGATORY,
|
| 72 |
+
HarmCategory.HARM_CATEGORY_DEROGATORY: protos.HarmCategory.HARM_CATEGORY_DEROGATORY,
|
| 73 |
+
1: protos.HarmCategory.HARM_CATEGORY_DEROGATORY,
|
| 74 |
+
"harm_category_derogatory": protos.HarmCategory.HARM_CATEGORY_DEROGATORY,
|
| 75 |
+
"derogatory": protos.HarmCategory.HARM_CATEGORY_DEROGATORY,
|
| 76 |
+
|
| 77 |
+
protos.HarmCategory.HARM_CATEGORY_TOXICITY: protos.HarmCategory.HARM_CATEGORY_TOXICITY,
|
| 78 |
+
HarmCategory.HARM_CATEGORY_TOXICITY: protos.HarmCategory.HARM_CATEGORY_TOXICITY,
|
| 79 |
+
2: protos.HarmCategory.HARM_CATEGORY_TOXICITY,
|
| 80 |
+
"harm_category_toxicity": protos.HarmCategory.HARM_CATEGORY_TOXICITY,
|
| 81 |
+
"toxicity": protos.HarmCategory.HARM_CATEGORY_TOXICITY,
|
| 82 |
+
"toxic": protos.HarmCategory.HARM_CATEGORY_TOXICITY,
|
| 83 |
+
|
| 84 |
+
protos.HarmCategory.HARM_CATEGORY_VIOLENCE: protos.HarmCategory.HARM_CATEGORY_VIOLENCE,
|
| 85 |
+
HarmCategory.HARM_CATEGORY_VIOLENCE: protos.HarmCategory.HARM_CATEGORY_VIOLENCE,
|
| 86 |
+
3: protos.HarmCategory.HARM_CATEGORY_VIOLENCE,
|
| 87 |
+
"harm_category_violence": protos.HarmCategory.HARM_CATEGORY_VIOLENCE,
|
| 88 |
+
"violence": protos.HarmCategory.HARM_CATEGORY_VIOLENCE,
|
| 89 |
+
"violent": protos.HarmCategory.HARM_CATEGORY_VIOLENCE,
|
| 90 |
+
|
| 91 |
+
protos.HarmCategory.HARM_CATEGORY_SEXUAL: protos.HarmCategory.HARM_CATEGORY_SEXUAL,
|
| 92 |
+
HarmCategory.HARM_CATEGORY_SEXUAL: protos.HarmCategory.HARM_CATEGORY_SEXUAL,
|
| 93 |
+
4: protos.HarmCategory.HARM_CATEGORY_SEXUAL,
|
| 94 |
+
"harm_category_sexual": protos.HarmCategory.HARM_CATEGORY_SEXUAL,
|
| 95 |
+
"sexual": protos.HarmCategory.HARM_CATEGORY_SEXUAL,
|
| 96 |
+
"sex": protos.HarmCategory.HARM_CATEGORY_SEXUAL,
|
| 97 |
+
|
| 98 |
+
protos.HarmCategory.HARM_CATEGORY_MEDICAL: protos.HarmCategory.HARM_CATEGORY_MEDICAL,
|
| 99 |
+
HarmCategory.HARM_CATEGORY_MEDICAL: protos.HarmCategory.HARM_CATEGORY_MEDICAL,
|
| 100 |
+
5: protos.HarmCategory.HARM_CATEGORY_MEDICAL,
|
| 101 |
+
"harm_category_medical": protos.HarmCategory.HARM_CATEGORY_MEDICAL,
|
| 102 |
+
"medical": protos.HarmCategory.HARM_CATEGORY_MEDICAL,
|
| 103 |
+
"med": protos.HarmCategory.HARM_CATEGORY_MEDICAL,
|
| 104 |
+
|
| 105 |
+
protos.HarmCategory.HARM_CATEGORY_DANGEROUS: protos.HarmCategory.HARM_CATEGORY_DANGEROUS,
|
| 106 |
+
HarmCategory.HARM_CATEGORY_DANGEROUS: protos.HarmCategory.HARM_CATEGORY_DANGEROUS,
|
| 107 |
+
6: protos.HarmCategory.HARM_CATEGORY_DANGEROUS,
|
| 108 |
+
"harm_category_dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS,
|
| 109 |
+
"dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS,
|
| 110 |
+
"danger": protos.HarmCategory.HARM_CATEGORY_DANGEROUS,
|
| 111 |
+
}
|
| 112 |
+
# fmt: on
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def to_harm_category(x: HarmCategoryOptions) -> protos.HarmCategory:
    """Coerce a string/int/enum harm-category spec into `protos.HarmCategory`."""
    key = x.lower() if isinstance(x, str) else x
    return _HARM_CATEGORIES[key]
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold]
|
| 122 |
+
|
| 123 |
+
# fmt: off
|
| 124 |
+
_BLOCK_THRESHOLDS: Dict[HarmBlockThresholdOptions, HarmBlockThreshold] = {
|
| 125 |
+
HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
|
| 126 |
+
0: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
|
| 127 |
+
"harm_block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
|
| 128 |
+
"block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
|
| 129 |
+
"unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
|
| 130 |
+
|
| 131 |
+
HarmBlockThreshold.BLOCK_LOW_AND_ABOVE: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
| 132 |
+
1: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
| 133 |
+
"block_low_and_above": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
| 134 |
+
"low": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
| 135 |
+
|
| 136 |
+
HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
| 137 |
+
2: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
| 138 |
+
"block_medium_and_above": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
| 139 |
+
"medium": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
| 140 |
+
"med": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
| 141 |
+
|
| 142 |
+
HarmBlockThreshold.BLOCK_ONLY_HIGH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
| 143 |
+
3: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
| 144 |
+
"block_only_high": HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
| 145 |
+
"high": HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
| 146 |
+
|
| 147 |
+
HarmBlockThreshold.BLOCK_NONE: HarmBlockThreshold.BLOCK_NONE,
|
| 148 |
+
4: HarmBlockThreshold.BLOCK_NONE,
|
| 149 |
+
"block_none": HarmBlockThreshold.BLOCK_NONE,
|
| 150 |
+
}
|
| 151 |
+
# fmt: on
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmBlockThreshold:
    """Coerce a string/int/enum threshold spec into a `HarmBlockThreshold`."""
    key = x.lower() if isinstance(x, str) else x
    return _BLOCK_THRESHOLDS[key]
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class ContentFilterDict(TypedDict):
    reason: BlockedReason  # why the content was filtered
    message: str  # human-readable detail

    # Reuse the proto's docstring (with oneof boilerplate stripped).
    __doc__ = string_utils.strip_oneof(protos.ContentFilter.__doc__)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def convert_filters_to_enums(
    filters: Iterable[dict],
) -> List[ContentFilterDict]:
    """Return copies of `filters` with `reason` coerced to `BlockedReason`."""
    converted = []
    for item in filters:
        entry = item.copy()
        entry["reason"] = BlockedReason(entry["reason"])
        converted.append(typing.cast(ContentFilterDict, entry))
    return converted
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class SafetyRatingDict(TypedDict):
    category: protos.HarmCategory  # which harm category was rated
    probability: HarmProbability  # rated probability for that category

    # Reuse the proto's docstring (with oneof boilerplate stripped).
    __doc__ = string_utils.strip_oneof(protos.SafetyRating.__doc__)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def convert_rating_to_enum(rating: dict) -> SafetyRatingDict:
    """Coerce a raw rating dict's fields to their enum types."""
    category = protos.HarmCategory(rating["category"])
    probability = HarmProbability(rating["probability"])
    return {"category": category, "probability": probability}
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]:
    """Apply `convert_rating_to_enum` to every rating."""
    return [convert_rating_to_enum(rating) for rating in ratings]
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class SafetySettingDict(TypedDict):
    category: protos.HarmCategory  # harm category the setting applies to
    threshold: HarmBlockThreshold  # blocking threshold for that category

    # Reuse the proto's docstring (with oneof boilerplate stripped).
    __doc__ = string_utils.strip_oneof(protos.SafetySetting.__doc__)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class LooseSafetySettingDict(TypedDict):
    """Like `SafetySettingDict`, but accepting loose (str/int/enum) values."""

    category: HarmCategoryOptions
    threshold: HarmBlockThresholdOptions
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions]
|
| 213 |
+
EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions]
|
| 214 |
+
|
| 215 |
+
SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None]
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict:
    """Flatten any accepted safety-settings form into {category: threshold}."""
    if settings is None:
        return {}
    if isinstance(settings, Mapping):
        pairs = settings.items()
    else:  # Iterable of {"category": ..., "threshold": ...} dicts
        pairs = ((d["category"], d["threshold"]) for d in settings)
    return {to_harm_category(c): to_block_threshold(t) for c, t in pairs}
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def normalize_safety_settings(
    settings: SafetySettingOptions,
) -> list[SafetySettingDict] | None:
    """Expand any accepted safety-settings form into a list of setting dicts.

    Returns None when `settings` is None (meaning: use service defaults).
    """
    if settings is None:
        return None
    if isinstance(settings, Mapping):
        pairs = settings.items()
    else:  # Iterable of {"category": ..., "threshold": ...} dicts
        pairs = ((d["category"], d["threshold"]) for d in settings)
    return [
        {
            "category": to_harm_category(category),
            "threshold": to_block_threshold(threshold),
        }
        for category, threshold in pairs
    ]
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def convert_setting_to_enum(setting: dict) -> SafetySettingDict:
    """Coerce a raw setting dict's fields to their enum types."""
    category = protos.HarmCategory(setting["category"])
    threshold = HarmBlockThreshold(setting["threshold"])
    return {"category": category, "threshold": threshold}
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class SafetyFeedbackDict(TypedDict):
    rating: SafetyRatingDict  # safety rating for the content
    setting: SafetySettingDict  # the safety setting in effect

    # Reuse the proto's docstring (with oneof boilerplate stripped).
    __doc__ = string_utils.strip_oneof(protos.SafetyFeedback.__doc__)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def convert_safety_feedback_to_enums(
    safety_feedback: Iterable[dict],
) -> List[SafetyFeedbackDict]:
    """Convert each feedback entry's rating and setting to enum-typed dicts."""
    return [
        {
            "rating": convert_rating_to_enum(entry["rating"]),
            "setting": convert_setting_to_enum(entry["setting"]),
        }
        for entry in safety_feedback
    ]
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def convert_candidate_enums(candidates):
    """Return shallow copies of *candidates* whose safety ratings use enum types."""
    converted = []
    for original in candidates:
        updated = original.copy()  # Don't mutate the caller's dict.
        updated["safety_ratings"] = convert_ratings_to_enum(updated["safety_ratings"])
        converted.append(updated)
    return converted
|
.venv/lib/python3.11/site-packages/google/generativeai/types/safety_types.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
from collections.abc import Mapping
|
| 18 |
+
|
| 19 |
+
import enum
|
| 20 |
+
import typing
|
| 21 |
+
from typing import Dict, Iterable, List, Union
|
| 22 |
+
|
| 23 |
+
from typing_extensions import TypedDict
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
from google.generativeai import protos
|
| 27 |
+
from google.generativeai import string_utils
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
__all__ = [
|
| 31 |
+
"HarmCategory",
|
| 32 |
+
"HarmProbability",
|
| 33 |
+
"HarmBlockThreshold",
|
| 34 |
+
"BlockedReason",
|
| 35 |
+
"ContentFilterDict",
|
| 36 |
+
"SafetyRatingDict",
|
| 37 |
+
"SafetySettingDict",
|
| 38 |
+
"SafetyFeedbackDict",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
# These are basic python enums, it's okay to expose them
|
| 42 |
+
HarmProbability = protos.SafetyRating.HarmProbability
|
| 43 |
+
HarmBlockThreshold = protos.SafetySetting.HarmBlockThreshold
|
| 44 |
+
BlockedReason = protos.ContentFilter.BlockedReason
|
| 45 |
+
|
| 46 |
+
import proto
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class HarmCategory(proto.Enum):
    """
    Harm Categories supported by the gemini-family model
    """

    # Mirror the numeric values of protos.HarmCategory so members of this
    # enum compare equal to (and convert from) the full proto enum, while
    # exposing only the gemini-relevant categories.
    HARM_CATEGORY_UNSPECIFIED = protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value
    HARM_CATEGORY_HARASSMENT = protos.HarmCategory.HARM_CATEGORY_HARASSMENT.value
    HARM_CATEGORY_HATE_SPEECH = protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH.value
    HARM_CATEGORY_SEXUALLY_EXPLICIT = protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT.value
    HARM_CATEGORY_DANGEROUS_CONTENT = protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT.value
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
HarmCategoryOptions = Union[str, int, HarmCategory]
|
| 62 |
+
|
| 63 |
+
# fmt: off
|
| 64 |
+
_HARM_CATEGORIES: Dict[HarmCategoryOptions, protos.HarmCategory] = {
|
| 65 |
+
protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
| 66 |
+
HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
| 67 |
+
0: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
| 68 |
+
"harm_category_unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
| 69 |
+
"unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
| 70 |
+
|
| 71 |
+
7: protos.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
| 72 |
+
protos.HarmCategory.HARM_CATEGORY_HARASSMENT: protos.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
| 73 |
+
HarmCategory.HARM_CATEGORY_HARASSMENT: protos.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
| 74 |
+
"harm_category_harassment": protos.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
| 75 |
+
"harassment": protos.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
| 76 |
+
|
| 77 |
+
8: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
| 78 |
+
protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
| 79 |
+
HarmCategory.HARM_CATEGORY_HATE_SPEECH: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
| 80 |
+
'harm_category_hate_speech': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
| 81 |
+
'hate_speech': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
| 82 |
+
'hate': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
| 83 |
+
|
| 84 |
+
9: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
| 85 |
+
protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
| 86 |
+
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
| 87 |
+
"harm_category_sexually_explicit": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
| 88 |
+
"harm_category_sexual": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
| 89 |
+
"sexually_explicit": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
| 90 |
+
"sexual": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
| 91 |
+
"sex": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
| 92 |
+
|
| 93 |
+
10: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
| 94 |
+
protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
| 95 |
+
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
| 96 |
+
"harm_category_dangerous_content": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
| 97 |
+
"harm_category_dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
| 98 |
+
"dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
| 99 |
+
"danger": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
| 100 |
+
}
|
| 101 |
+
# fmt: on
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def to_harm_category(x: HarmCategoryOptions) -> protos.HarmCategory:
    """Map a loose category spec (enum, int, or alias string) to the proto enum.

    String aliases are case-insensitive; unknown values raise KeyError.
    """
    key = x.lower() if isinstance(x, str) else x
    return _HARM_CATEGORIES[key]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold]
|
| 111 |
+
|
| 112 |
+
# fmt: off
|
| 113 |
+
_BLOCK_THRESHOLDS: Dict[HarmBlockThresholdOptions, HarmBlockThreshold] = {
|
| 114 |
+
HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
|
| 115 |
+
0: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
|
| 116 |
+
"harm_block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
|
| 117 |
+
"block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
|
| 118 |
+
"unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
|
| 119 |
+
|
| 120 |
+
HarmBlockThreshold.BLOCK_LOW_AND_ABOVE: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
| 121 |
+
1: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
| 122 |
+
"block_low_and_above": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
| 123 |
+
"low": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
| 124 |
+
|
| 125 |
+
HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
| 126 |
+
2: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
| 127 |
+
"block_medium_and_above": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
| 128 |
+
"medium": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
| 129 |
+
"med": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
| 130 |
+
|
| 131 |
+
HarmBlockThreshold.BLOCK_ONLY_HIGH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
| 132 |
+
3: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
| 133 |
+
"block_only_high": HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
| 134 |
+
"high": HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
| 135 |
+
|
| 136 |
+
HarmBlockThreshold.BLOCK_NONE: HarmBlockThreshold.BLOCK_NONE,
|
| 137 |
+
4: HarmBlockThreshold.BLOCK_NONE,
|
| 138 |
+
"block_none": HarmBlockThreshold.BLOCK_NONE,
|
| 139 |
+
}
|
| 140 |
+
# fmt: on
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmBlockThreshold:
    """Map a loose threshold spec (enum, int, or alias string) to the enum.

    String aliases are case-insensitive; unknown values raise KeyError.
    """
    key = x.lower() if isinstance(x, str) else x
    return _BLOCK_THRESHOLDS[key]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class ContentFilterDict(TypedDict):
    # Why the content was blocked.
    reason: BlockedReason
    # Human-readable detail accompanying the block reason.
    message: str

    # Reuse the proto message documentation, minus the oneof boilerplate.
    __doc__ = string_utils.strip_oneof(protos.ContentFilter.__doc__)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def convert_filters_to_enums(
    filters: Iterable[dict],
) -> List[ContentFilterDict]:
    """Return copies of *filters* whose "reason" field is a BlockedReason enum."""
    converted = []
    for item in filters:
        entry = item.copy()  # Don't mutate the caller's dict.
        entry["reason"] = BlockedReason(entry["reason"])
        converted.append(typing.cast(ContentFilterDict, entry))
    return converted
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class SafetyRatingDict(TypedDict):
    # The harm category this rating applies to.
    category: protos.HarmCategory
    # The model's assessed probability of harm for that category.
    probability: HarmProbability

    # Reuse the proto message documentation, minus the oneof boilerplate.
    __doc__ = string_utils.strip_oneof(protos.SafetyRating.__doc__)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def convert_rating_to_enum(rating: dict) -> SafetyRatingDict:
    """Coerce a raw rating dict's values into their proto enum types."""
    return {
        "category": protos.HarmCategory(rating["category"]),
        "probability": HarmProbability(rating["probability"]),
    }
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]:
    """Apply convert_rating_to_enum to every rating in *ratings*."""
    return [convert_rating_to_enum(rating) for rating in ratings]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class SafetySettingDict(TypedDict):
    # The harm category this setting controls.
    category: protos.HarmCategory
    # The blocking threshold applied to that category.
    threshold: HarmBlockThreshold

    # Reuse the proto message documentation, minus the oneof boilerplate.
    __doc__ = string_utils.strip_oneof(protos.SafetySetting.__doc__)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class LooseSafetySettingDict(TypedDict):
    # Like SafetySettingDict, but the fields accept the loose user-facing
    # forms (strings, ints, or enums) before normalization.
    category: HarmCategoryOptions
    threshold: HarmBlockThresholdOptions
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions]
|
| 202 |
+
EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions]
|
| 203 |
+
|
| 204 |
+
SafetySettingOptions = Union[
|
| 205 |
+
HarmBlockThresholdOptions, EasySafetySetting, Iterable[LooseSafetySettingDict], None
|
| 206 |
+
]
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _expand_block_threshold(block_threshold: HarmBlockThresholdOptions):
    """Fan a single threshold out to every concrete (non-unspecified) category."""
    threshold = to_block_threshold(block_threshold)
    categories = set(_HARM_CATEGORIES.values()) - {protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED}
    return dict.fromkeys(categories, threshold)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict:
    """Normalize any accepted safety-settings form into ``{category: threshold}``.

    Accepts ``None``, a bare threshold (applied to every concrete category), a
    mapping of loose categories to loose thresholds, or an iterable of
    ``protos.SafetySetting`` messages / loose setting dicts.
    """
    if settings is None:
        return {}

    # A bare threshold is shorthand for "this threshold on every category".
    if isinstance(settings, (int, str, HarmBlockThreshold)):
        settings = _expand_block_threshold(settings)

    if isinstance(settings, Mapping):
        return {to_harm_category(k): to_block_threshold(v) for k, v in settings.items()}

    # Otherwise: an iterable of per-category settings (protos or dicts).
    result = {}
    for setting in settings:
        if isinstance(setting, protos.SafetySetting):
            category, threshold = setting.category, setting.threshold
        elif isinstance(setting, dict):
            category, threshold = setting["category"], setting["threshold"]
        else:
            raise ValueError(
                f"Could not understand safety setting:\n {type(setting)=}\n {setting=}"
            )
        result[to_harm_category(category)] = to_block_threshold(threshold)
    return result
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def normalize_safety_settings(
    settings: SafetySettingOptions,
) -> list[SafetySettingDict] | None:
    """Convert any accepted safety-settings form to a list of setting dicts.

    ``None`` passes through unchanged; a bare threshold is expanded to every
    concrete category first; mappings and iterables of loose setting dicts
    become ``[{"category": ..., "threshold": ...}, ...]`` with enum values.
    """
    if settings is None:
        return None

    # A bare threshold is shorthand for "this threshold on every category".
    if isinstance(settings, (int, str, HarmBlockThreshold)):
        settings = _expand_block_threshold(settings)

    if isinstance(settings, Mapping):
        pairs = settings.items()
    else:
        pairs = ((d["category"], d["threshold"]) for d in settings)

    return [
        {"category": to_harm_category(category), "threshold": to_block_threshold(threshold)}
        for category, threshold in pairs
    ]
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def convert_setting_to_enum(setting: dict) -> SafetySettingDict:
    """Coerce a raw setting dict's values into their proto enum types."""
    return {
        "category": protos.HarmCategory(setting["category"]),
        "threshold": HarmBlockThreshold(setting["threshold"]),
    }
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class SafetyFeedbackDict(TypedDict):
    # Pairs the rating that triggered blocking with the safety setting that
    # was in effect for that category.
    rating: SafetyRatingDict
    setting: SafetySettingDict

    # Reuse the proto message documentation, minus the oneof boilerplate.
    __doc__ = string_utils.strip_oneof(protos.SafetyFeedback.__doc__)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def convert_safety_feedback_to_enums(
    safety_feedback: Iterable[dict],
) -> List[SafetyFeedbackDict]:
    """Convert each feedback entry's rating and setting to enum-typed dicts."""
    return [
        {
            "rating": convert_rating_to_enum(entry["rating"]),
            "setting": convert_setting_to_enum(entry["setting"]),
        }
        for entry in safety_feedback
    ]
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def convert_candidate_enums(candidates):
    """Return shallow copies of *candidates* whose safety ratings use enum types."""
    converted = []
    for original in candidates:
        updated = original.copy()  # Don't mutate the caller's dict.
        updated["safety_ratings"] = convert_ratings_to_enum(updated["safety_ratings"])
        converted.append(updated)
    return converted
|
.venv/lib/python3.11/site-packages/google/generativeai/types/text_types.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2023 Google LLC
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import sys
|
| 18 |
+
import abc
|
| 19 |
+
import dataclasses
|
| 20 |
+
from typing import Any, Dict, List
|
| 21 |
+
from typing_extensions import TypedDict
|
| 22 |
+
|
| 23 |
+
from google.generativeai import string_utils
|
| 24 |
+
from google.generativeai.types import citation_types
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class EmbeddingDict(TypedDict):
    # A single embedding vector for one piece of text.
    embedding: list[float]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BatchEmbeddingDict(TypedDict):
    # One embedding vector per input text, in input order.
    embedding: list[list[float]]
|
.venv/lib/python3.11/site-packages/google/logging/type/__pycache__/http_request_pb2.cpython-311.pyc
ADDED
|
Binary file (2.28 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/logging/type/__pycache__/log_severity_pb2.cpython-311.pyc
ADDED
|
Binary file (1.94 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/google/logging/type/http_request.proto
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2024 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
syntax = "proto3";
|
| 16 |
+
|
| 17 |
+
package google.logging.type;
|
| 18 |
+
|
| 19 |
+
import "google/protobuf/duration.proto";
|
| 20 |
+
|
| 21 |
+
option csharp_namespace = "Google.Cloud.Logging.Type";
|
| 22 |
+
option go_package = "google.golang.org/genproto/googleapis/logging/type;ltype";
|
| 23 |
+
option java_multiple_files = true;
|
| 24 |
+
option java_outer_classname = "HttpRequestProto";
|
| 25 |
+
option java_package = "com.google.logging.type";
|
| 26 |
+
option php_namespace = "Google\\Cloud\\Logging\\Type";
|
| 27 |
+
option ruby_package = "Google::Cloud::Logging::Type";
|
| 28 |
+
|
| 29 |
+
// A common proto for logging HTTP requests. Only contains semantics
|
| 30 |
+
// defined by the HTTP specification. Product-specific logging
|
| 31 |
+
// information MUST be defined in a separate message.
|
| 32 |
+
message HttpRequest {
|
| 33 |
+
// The request method. Examples: `"GET"`, `"HEAD"`, `"PUT"`, `"POST"`.
|
| 34 |
+
string request_method = 1;
|
| 35 |
+
|
| 36 |
+
// The scheme (http, https), the host name, the path and the query
|
| 37 |
+
// portion of the URL that was requested.
|
| 38 |
+
// Example: `"http://example.com/some/info?color=red"`.
|
| 39 |
+
string request_url = 2;
|
| 40 |
+
|
| 41 |
+
// The size of the HTTP request message in bytes, including the request
|
| 42 |
+
// headers and the request body.
|
| 43 |
+
int64 request_size = 3;
|
| 44 |
+
|
| 45 |
+
// The response code indicating the status of response.
|
| 46 |
+
// Examples: 200, 404.
|
| 47 |
+
int32 status = 4;
|
| 48 |
+
|
| 49 |
+
// The size of the HTTP response message sent back to the client, in bytes,
|
| 50 |
+
// including the response headers and the response body.
|
| 51 |
+
int64 response_size = 5;
|
| 52 |
+
|
| 53 |
+
// The user agent sent by the client. Example:
|
| 54 |
+
// `"Mozilla/4.0 (compatible; MSIE 6.0; Windows 98; Q312461; .NET
|
| 55 |
+
// CLR 1.0.3705)"`.
|
| 56 |
+
string user_agent = 6;
|
| 57 |
+
|
| 58 |
+
// The IP address (IPv4 or IPv6) of the client that issued the HTTP
|
| 59 |
+
// request. This field can include port information. Examples:
|
| 60 |
+
// `"192.168.1.1"`, `"10.0.0.1:80"`, `"FE80::0202:B3FF:FE1E:8329"`.
|
| 61 |
+
string remote_ip = 7;
|
| 62 |
+
|
| 63 |
+
// The IP address (IPv4 or IPv6) of the origin server that the request was
|
| 64 |
+
// sent to. This field can include port information. Examples:
|
| 65 |
+
// `"192.168.1.1"`, `"10.0.0.1:80"`, `"FE80::0202:B3FF:FE1E:8329"`.
|
| 66 |
+
string server_ip = 13;
|
| 67 |
+
|
| 68 |
+
// The referer URL of the request, as defined in
|
| 69 |
+
// [HTTP/1.1 Header Field
|
| 70 |
+
// Definitions](https://datatracker.ietf.org/doc/html/rfc2616#section-14.36).
|
| 71 |
+
string referer = 8;
|
| 72 |
+
|
| 73 |
+
// The request processing latency on the server, from the time the request was
|
| 74 |
+
// received until the response was sent.
|
| 75 |
+
google.protobuf.Duration latency = 14;
|
| 76 |
+
|
| 77 |
+
// Whether or not a cache lookup was attempted.
|
| 78 |
+
bool cache_lookup = 11;
|
| 79 |
+
|
| 80 |
+
// Whether or not an entity was served from cache
|
| 81 |
+
// (with or without validation).
|
| 82 |
+
bool cache_hit = 9;
|
| 83 |
+
|
| 84 |
+
// Whether or not the response was validated with the origin server before
|
| 85 |
+
// being served from cache. This field is only meaningful if `cache_hit` is
|
| 86 |
+
// True.
|
| 87 |
+
bool cache_validated_with_origin_server = 10;
|
| 88 |
+
|
| 89 |
+
// The number of HTTP response bytes inserted into cache. Set only when a
|
| 90 |
+
// cache fill was attempted.
|
| 91 |
+
int64 cache_fill_bytes = 12;
|
| 92 |
+
|
| 93 |
+
// Protocol used for the request. Examples: "HTTP/1.1", "HTTP/2", "websocket"
|
| 94 |
+
string protocol = 15;
|
| 95 |
+
}
|
.venv/lib/python3.11/site-packages/google/logging/type/log_severity.proto
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2024 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
syntax = "proto3";
|
| 16 |
+
|
| 17 |
+
package google.logging.type;
|
| 18 |
+
|
| 19 |
+
option csharp_namespace = "Google.Cloud.Logging.Type";
|
| 20 |
+
option go_package = "google.golang.org/genproto/googleapis/logging/type;ltype";
|
| 21 |
+
option java_multiple_files = true;
|
| 22 |
+
option java_outer_classname = "LogSeverityProto";
|
| 23 |
+
option java_package = "com.google.logging.type";
|
| 24 |
+
option objc_class_prefix = "GLOG";
|
| 25 |
+
option php_namespace = "Google\\Cloud\\Logging\\Type";
|
| 26 |
+
option ruby_package = "Google::Cloud::Logging::Type";
|
| 27 |
+
|
| 28 |
+
// The severity of the event described in a log entry, expressed as one of the
|
| 29 |
+
// standard severity levels listed below. For your reference, the levels are
|
| 30 |
+
// assigned the listed numeric values. The effect of using numeric values other
|
| 31 |
+
// than those listed is undefined.
|
| 32 |
+
//
|
| 33 |
+
// You can filter for log entries by severity. For example, the following
|
| 34 |
+
// filter expression will match log entries with severities `INFO`, `NOTICE`,
|
| 35 |
+
// and `WARNING`:
|
| 36 |
+
//
|
| 37 |
+
// severity > DEBUG AND severity <= WARNING
|
| 38 |
+
//
|
| 39 |
+
// If you are writing log entries, you should map other severity encodings to
|
| 40 |
+
// one of these standard levels. For example, you might map all of Java's FINE,
|
| 41 |
+
// FINER, and FINEST levels to `LogSeverity.DEBUG`. You can preserve the
|
| 42 |
+
// original severity level in the log entry payload if you wish.
|
| 43 |
+
enum LogSeverity {
|
| 44 |
+
// (0) The log entry has no assigned severity level.
|
| 45 |
+
DEFAULT = 0;
|
| 46 |
+
|
| 47 |
+
// (100) Debug or trace information.
|
| 48 |
+
DEBUG = 100;
|
| 49 |
+
|
| 50 |
+
// (200) Routine information, such as ongoing status or performance.
|
| 51 |
+
INFO = 200;
|
| 52 |
+
|
| 53 |
+
// (300) Normal but significant events, such as start up, shut down, or
|
| 54 |
+
// a configuration change.
|
| 55 |
+
NOTICE = 300;
|
| 56 |
+
|
| 57 |
+
// (400) Warning events might cause problems.
|
| 58 |
+
WARNING = 400;
|
| 59 |
+
|
| 60 |
+
// (500) Error events are likely to cause problems.
|
| 61 |
+
ERROR = 500;
|
| 62 |
+
|
| 63 |
+
// (600) Critical events cause more severe problems or outages.
|
| 64 |
+
CRITICAL = 600;
|
| 65 |
+
|
| 66 |
+
// (700) A person must take an action immediately.
|
| 67 |
+
ALERT = 700;
|
| 68 |
+
|
| 69 |
+
// (800) One or more systems are unusable.
|
| 70 |
+
EMERGENCY = 800;
|
| 71 |
+
}
|
.venv/lib/python3.11/site-packages/google/logging/type/log_severity_pb2.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 Google LLC
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 18 |
+
# source: google/logging/type/log_severity.proto
|
| 19 |
+
"""Generated protocol buffer code."""
|
| 20 |
+
from google.protobuf import descriptor as _descriptor
|
| 21 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 22 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 23 |
+
from google.protobuf.internal import builder as _builder
|
| 24 |
+
|
| 25 |
+
# @@protoc_insertion_point(imports)
|
| 26 |
+
|
| 27 |
+
_sym_db = _symbol_database.Default()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
| 31 |
+
b"\n&google/logging/type/log_severity.proto\x12\x13google.logging.type*\x82\x01\n\x0bLogSeverity\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x64\x12\t\n\x04INFO\x10\xc8\x01\x12\x0b\n\x06NOTICE\x10\xac\x02\x12\x0c\n\x07WARNING\x10\x90\x03\x12\n\n\x05\x45RROR\x10\xf4\x03\x12\r\n\x08\x43RITICAL\x10\xd8\x04\x12\n\n\x05\x41LERT\x10\xbc\x05\x12\x0e\n\tEMERGENCY\x10\xa0\x06\x42\xc5\x01\n\x17\x63om.google.logging.typeB\x10LogSeverityProtoP\x01Z8google.golang.org/genproto/googleapis/logging/type;ltype\xa2\x02\x04GLOG\xaa\x02\x19Google.Cloud.Logging.Type\xca\x02\x19Google\\Cloud\\Logging\\Type\xea\x02\x1cGoogle::Cloud::Logging::Typeb\x06proto3"
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
_globals = globals()
|
| 35 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
| 36 |
+
_builder.BuildTopDescriptorsAndMessages(
|
| 37 |
+
DESCRIPTOR, "google.logging.type.log_severity_pb2", _globals
|
| 38 |
+
)
|
| 39 |
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
| 40 |
+
DESCRIPTOR._options = None
|
| 41 |
+
DESCRIPTOR._serialized_options = b"\n\027com.google.logging.typeB\020LogSeverityProtoP\001Z8google.golang.org/genproto/googleapis/logging/type;ltype\242\002\004GLOG\252\002\031Google.Cloud.Logging.Type\312\002\031Google\\Cloud\\Logging\\Type\352\002\034Google::Cloud::Logging::Type"
|
| 42 |
+
_globals["_LOGSEVERITY"]._serialized_start = 64
|
| 43 |
+
_globals["_LOGSEVERITY"]._serialized_end = 194
|
| 44 |
+
# @@protoc_insertion_point(module_scope)
|
.venv/lib/python3.11/site-packages/google/protobuf/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Protocol Buffers - Google's data interchange format
|
| 2 |
+
# Copyright 2008 Google Inc. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Use of this source code is governed by a BSD-style
|
| 5 |
+
# license that can be found in the LICENSE file or at
|
| 6 |
+
# https://developers.google.com/open-source/licenses/bsd
|
| 7 |
+
|
| 8 |
+
# Copyright 2007 Google Inc. All Rights Reserved.
|
| 9 |
+
|
| 10 |
+
__version__ = '5.29.3'
|
.venv/lib/python3.11/site-packages/google/protobuf/any.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Protocol Buffers - Google's data interchange format
|
| 2 |
+
# Copyright 2008 Google Inc. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Use of this source code is governed by a BSD-style
|
| 5 |
+
# license that can be found in the LICENSE file or at
|
| 6 |
+
# https://developers.google.com/open-source/licenses/bsd
|
| 7 |
+
|
| 8 |
+
"""Contains the Any helper APIs."""
|
| 9 |
+
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
from google.protobuf import descriptor
|
| 13 |
+
from google.protobuf.message import Message
|
| 14 |
+
|
| 15 |
+
from google.protobuf.any_pb2 import Any
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def pack(
|
| 19 |
+
msg: Message,
|
| 20 |
+
type_url_prefix: Optional[str] = 'type.googleapis.com/',
|
| 21 |
+
deterministic: Optional[bool] = None,
|
| 22 |
+
) -> Any:
|
| 23 |
+
any_msg = Any()
|
| 24 |
+
any_msg.Pack(
|
| 25 |
+
msg=msg, type_url_prefix=type_url_prefix, deterministic=deterministic
|
| 26 |
+
)
|
| 27 |
+
return any_msg
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def unpack(any_msg: Any, msg: Message) -> bool:
|
| 31 |
+
return any_msg.Unpack(msg=msg)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def type_name(any_msg: Any) -> str:
|
| 35 |
+
return any_msg.TypeName()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def is_type(any_msg: Any, des: descriptor.Descriptor) -> bool:
|
| 39 |
+
return any_msg.Is(des)
|
.venv/lib/python3.11/site-packages/google/protobuf/any_pb2.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# NO CHECKED-IN PROTOBUF GENCODE
|
| 4 |
+
# source: google/protobuf/any.proto
|
| 5 |
+
# Protobuf Python Version: 5.29.3
|
| 6 |
+
"""Generated protocol buffer code."""
|
| 7 |
+
from google.protobuf import descriptor as _descriptor
|
| 8 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 9 |
+
from google.protobuf import runtime_version as _runtime_version
|
| 10 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 11 |
+
from google.protobuf.internal import builder as _builder
|
| 12 |
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
| 13 |
+
_runtime_version.Domain.PUBLIC,
|
| 14 |
+
5,
|
| 15 |
+
29,
|
| 16 |
+
3,
|
| 17 |
+
'',
|
| 18 |
+
'google/protobuf/any.proto'
|
| 19 |
+
)
|
| 20 |
+
# @@protoc_insertion_point(imports)
|
| 21 |
+
|
| 22 |
+
_sym_db = _symbol_database.Default()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19google/protobuf/any.proto\x12\x0fgoogle.protobuf\"6\n\x03\x41ny\x12\x19\n\x08type_url\x18\x01 \x01(\tR\x07typeUrl\x12\x14\n\x05value\x18\x02 \x01(\x0cR\x05valueBv\n\x13\x63om.google.protobufB\x08\x41nyProtoP\x01Z,google.golang.org/protobuf/types/known/anypb\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')
|
| 28 |
+
|
| 29 |
+
_globals = globals()
|
| 30 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
| 31 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.any_pb2', _globals)
|
| 32 |
+
if not _descriptor._USE_C_DESCRIPTORS:
|
| 33 |
+
_globals['DESCRIPTOR']._loaded_options = None
|
| 34 |
+
_globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\010AnyProtoP\001Z,google.golang.org/protobuf/types/known/anypb\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
|
| 35 |
+
_globals['_ANY']._serialized_start=46
|
| 36 |
+
_globals['_ANY']._serialized_end=100
|
| 37 |
+
# @@protoc_insertion_point(module_scope)
|
.venv/lib/python3.11/site-packages/google/protobuf/api_pb2.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# NO CHECKED-IN PROTOBUF GENCODE
|
| 4 |
+
# source: google/protobuf/api.proto
|
| 5 |
+
# Protobuf Python Version: 5.29.3
|
| 6 |
+
"""Generated protocol buffer code."""
|
| 7 |
+
from google.protobuf import descriptor as _descriptor
|
| 8 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 9 |
+
from google.protobuf import runtime_version as _runtime_version
|
| 10 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 11 |
+
from google.protobuf.internal import builder as _builder
|
| 12 |
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
| 13 |
+
_runtime_version.Domain.PUBLIC,
|
| 14 |
+
5,
|
| 15 |
+
29,
|
| 16 |
+
3,
|
| 17 |
+
'',
|
| 18 |
+
'google/protobuf/api.proto'
|
| 19 |
+
)
|
| 20 |
+
# @@protoc_insertion_point(imports)
|
| 21 |
+
|
| 22 |
+
_sym_db = _symbol_database.Default()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
from google.protobuf import source_context_pb2 as google_dot_protobuf_dot_source__context__pb2
|
| 26 |
+
from google.protobuf import type_pb2 as google_dot_protobuf_dot_type__pb2
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19google/protobuf/api.proto\x12\x0fgoogle.protobuf\x1a$google/protobuf/source_context.proto\x1a\x1agoogle/protobuf/type.proto\"\xc1\x02\n\x03\x41pi\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x31\n\x07methods\x18\x02 \x03(\x0b\x32\x17.google.protobuf.MethodR\x07methods\x12\x31\n\x07options\x18\x03 \x03(\x0b\x32\x17.google.protobuf.OptionR\x07options\x12\x18\n\x07version\x18\x04 \x01(\tR\x07version\x12\x45\n\x0esource_context\x18\x05 \x01(\x0b\x32\x1e.google.protobuf.SourceContextR\rsourceContext\x12.\n\x06mixins\x18\x06 \x03(\x0b\x32\x16.google.protobuf.MixinR\x06mixins\x12/\n\x06syntax\x18\x07 \x01(\x0e\x32\x17.google.protobuf.SyntaxR\x06syntax\"\xb2\x02\n\x06Method\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12(\n\x10request_type_url\x18\x02 \x01(\tR\x0erequestTypeUrl\x12+\n\x11request_streaming\x18\x03 \x01(\x08R\x10requestStreaming\x12*\n\x11response_type_url\x18\x04 \x01(\tR\x0fresponseTypeUrl\x12-\n\x12response_streaming\x18\x05 \x01(\x08R\x11responseStreaming\x12\x31\n\x07options\x18\x06 \x03(\x0b\x32\x17.google.protobuf.OptionR\x07options\x12/\n\x06syntax\x18\x07 \x01(\x0e\x32\x17.google.protobuf.SyntaxR\x06syntax\"/\n\x05Mixin\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x12\n\x04root\x18\x02 \x01(\tR\x04rootBv\n\x13\x63om.google.protobufB\x08\x41piProtoP\x01Z,google.golang.org/protobuf/types/known/apipb\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')
|
| 30 |
+
|
| 31 |
+
_globals = globals()
|
| 32 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
| 33 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.api_pb2', _globals)
|
| 34 |
+
if not _descriptor._USE_C_DESCRIPTORS:
|
| 35 |
+
_globals['DESCRIPTOR']._loaded_options = None
|
| 36 |
+
_globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\010ApiProtoP\001Z,google.golang.org/protobuf/types/known/apipb\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
|
| 37 |
+
_globals['_API']._serialized_start=113
|
| 38 |
+
_globals['_API']._serialized_end=434
|
| 39 |
+
_globals['_METHOD']._serialized_start=437
|
| 40 |
+
_globals['_METHOD']._serialized_end=743
|
| 41 |
+
_globals['_MIXIN']._serialized_start=745
|
| 42 |
+
_globals['_MIXIN']._serialized_end=792
|
| 43 |
+
# @@protoc_insertion_point(module_scope)
|
.venv/lib/python3.11/site-packages/google/protobuf/descriptor.py
ADDED
|
@@ -0,0 +1,1511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Protocol Buffers - Google's data interchange format
|
| 2 |
+
# Copyright 2008 Google Inc. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Use of this source code is governed by a BSD-style
|
| 5 |
+
# license that can be found in the LICENSE file or at
|
| 6 |
+
# https://developers.google.com/open-source/licenses/bsd
|
| 7 |
+
|
| 8 |
+
"""Descriptors essentially contain exactly the information found in a .proto
|
| 9 |
+
file, in types that make this information accessible in Python.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
__author__ = 'robinson@google.com (Will Robinson)'
|
| 13 |
+
|
| 14 |
+
import abc
|
| 15 |
+
import binascii
|
| 16 |
+
import os
|
| 17 |
+
import threading
|
| 18 |
+
import warnings
|
| 19 |
+
|
| 20 |
+
from google.protobuf.internal import api_implementation
|
| 21 |
+
|
| 22 |
+
_USE_C_DESCRIPTORS = False
|
| 23 |
+
if api_implementation.Type() != 'python':
|
| 24 |
+
# pylint: disable=protected-access
|
| 25 |
+
_message = api_implementation._c_module
|
| 26 |
+
# TODO: Remove this import after fix api_implementation
|
| 27 |
+
if _message is None:
|
| 28 |
+
from google.protobuf.pyext import _message
|
| 29 |
+
_USE_C_DESCRIPTORS = True
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Error(Exception):
|
| 33 |
+
"""Base error for this module."""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TypeTransformationError(Error):
|
| 37 |
+
"""Error transforming between python proto type and corresponding C++ type."""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if _USE_C_DESCRIPTORS:
|
| 41 |
+
# This metaclass allows to override the behavior of code like
|
| 42 |
+
# isinstance(my_descriptor, FieldDescriptor)
|
| 43 |
+
# and make it return True when the descriptor is an instance of the extension
|
| 44 |
+
# type written in C++.
|
| 45 |
+
class DescriptorMetaclass(type):
|
| 46 |
+
|
| 47 |
+
def __instancecheck__(cls, obj):
|
| 48 |
+
if super(DescriptorMetaclass, cls).__instancecheck__(obj):
|
| 49 |
+
return True
|
| 50 |
+
if isinstance(obj, cls._C_DESCRIPTOR_CLASS):
|
| 51 |
+
return True
|
| 52 |
+
return False
|
| 53 |
+
else:
|
| 54 |
+
# The standard metaclass; nothing changes.
|
| 55 |
+
DescriptorMetaclass = abc.ABCMeta
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class _Lock(object):
|
| 59 |
+
"""Wrapper class of threading.Lock(), which is allowed by 'with'."""
|
| 60 |
+
|
| 61 |
+
def __new__(cls):
|
| 62 |
+
self = object.__new__(cls)
|
| 63 |
+
self._lock = threading.Lock() # pylint: disable=protected-access
|
| 64 |
+
return self
|
| 65 |
+
|
| 66 |
+
def __enter__(self):
|
| 67 |
+
self._lock.acquire()
|
| 68 |
+
|
| 69 |
+
def __exit__(self, exc_type, exc_value, exc_tb):
|
| 70 |
+
self._lock.release()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
_lock = threading.Lock()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _Deprecated(name):
|
| 77 |
+
if _Deprecated.count > 0:
|
| 78 |
+
_Deprecated.count -= 1
|
| 79 |
+
warnings.warn(
|
| 80 |
+
'Call to deprecated create function %s(). Note: Create unlinked '
|
| 81 |
+
'descriptors is going to go away. Please use get/find descriptors from '
|
| 82 |
+
'generated code or query the descriptor_pool.'
|
| 83 |
+
% name,
|
| 84 |
+
category=DeprecationWarning, stacklevel=3)
|
| 85 |
+
|
| 86 |
+
# These must match the values in descriptor.proto, but we can't use them
|
| 87 |
+
# directly because we sometimes need to reference them in feature helpers
|
| 88 |
+
# below *during* the build of descriptor.proto.
|
| 89 |
+
_FEATURESET_MESSAGE_ENCODING_DELIMITED = 2
|
| 90 |
+
_FEATURESET_FIELD_PRESENCE_IMPLICIT = 2
|
| 91 |
+
_FEATURESET_FIELD_PRESENCE_LEGACY_REQUIRED = 3
|
| 92 |
+
_FEATURESET_REPEATED_FIELD_ENCODING_PACKED = 1
|
| 93 |
+
_FEATURESET_ENUM_TYPE_CLOSED = 2
|
| 94 |
+
|
| 95 |
+
# Deprecated warnings will print 100 times at most which should be enough for
|
| 96 |
+
# users to notice and do not cause timeout.
|
| 97 |
+
_Deprecated.count = 100
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
_internal_create_key = object()
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class DescriptorBase(metaclass=DescriptorMetaclass):
|
| 104 |
+
|
| 105 |
+
"""Descriptors base class.
|
| 106 |
+
|
| 107 |
+
This class is the base of all descriptor classes. It provides common options
|
| 108 |
+
related functionality.
|
| 109 |
+
|
| 110 |
+
Attributes:
|
| 111 |
+
has_options: True if the descriptor has non-default options. Usually it is
|
| 112 |
+
not necessary to read this -- just call GetOptions() which will happily
|
| 113 |
+
return the default instance. However, it's sometimes useful for
|
| 114 |
+
efficiency, and also useful inside the protobuf implementation to avoid
|
| 115 |
+
some bootstrapping issues.
|
| 116 |
+
file (FileDescriptor): Reference to file info.
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
if _USE_C_DESCRIPTORS:
|
| 120 |
+
# The class, or tuple of classes, that are considered as "virtual
|
| 121 |
+
# subclasses" of this descriptor class.
|
| 122 |
+
_C_DESCRIPTOR_CLASS = ()
|
| 123 |
+
|
| 124 |
+
def __init__(self, file, options, serialized_options, options_class_name):
|
| 125 |
+
"""Initialize the descriptor given its options message and the name of the
|
| 126 |
+
class of the options message. The name of the class is required in case
|
| 127 |
+
the options message is None and has to be created.
|
| 128 |
+
"""
|
| 129 |
+
self._features = None
|
| 130 |
+
self.file = file
|
| 131 |
+
self._options = options
|
| 132 |
+
self._loaded_options = None
|
| 133 |
+
self._options_class_name = options_class_name
|
| 134 |
+
self._serialized_options = serialized_options
|
| 135 |
+
|
| 136 |
+
# Does this descriptor have non-default options?
|
| 137 |
+
self.has_options = (self._options is not None) or (
|
| 138 |
+
self._serialized_options is not None
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
@property
|
| 142 |
+
@abc.abstractmethod
|
| 143 |
+
def _parent(self):
|
| 144 |
+
pass
|
| 145 |
+
|
| 146 |
+
def _InferLegacyFeatures(self, edition, options, features):
|
| 147 |
+
"""Infers features from proto2/proto3 syntax so that editions logic can be used everywhere.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
edition: The edition to infer features for.
|
| 151 |
+
options: The options for this descriptor that are being processed.
|
| 152 |
+
features: The feature set object to modify with inferred features.
|
| 153 |
+
"""
|
| 154 |
+
pass
|
| 155 |
+
|
| 156 |
+
def _GetFeatures(self):
|
| 157 |
+
if not self._features:
|
| 158 |
+
self._LazyLoadOptions()
|
| 159 |
+
return self._features
|
| 160 |
+
|
| 161 |
+
def _ResolveFeatures(self, edition, raw_options):
|
| 162 |
+
"""Resolves features from the raw options of this descriptor.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
edition: The edition to use for feature defaults.
|
| 166 |
+
raw_options: The options for this descriptor that are being processed.
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
A fully resolved feature set for making runtime decisions.
|
| 170 |
+
"""
|
| 171 |
+
# pylint: disable=g-import-not-at-top
|
| 172 |
+
from google.protobuf import descriptor_pb2
|
| 173 |
+
|
| 174 |
+
if self._parent:
|
| 175 |
+
features = descriptor_pb2.FeatureSet()
|
| 176 |
+
features.CopyFrom(self._parent._GetFeatures())
|
| 177 |
+
else:
|
| 178 |
+
features = self.file.pool._CreateDefaultFeatures(edition)
|
| 179 |
+
unresolved = descriptor_pb2.FeatureSet()
|
| 180 |
+
unresolved.CopyFrom(raw_options.features)
|
| 181 |
+
self._InferLegacyFeatures(edition, raw_options, unresolved)
|
| 182 |
+
features.MergeFrom(unresolved)
|
| 183 |
+
|
| 184 |
+
# Use the feature cache to reduce memory bloat.
|
| 185 |
+
return self.file.pool._InternFeatures(features)
|
| 186 |
+
|
| 187 |
+
def _LazyLoadOptions(self):
|
| 188 |
+
"""Lazily initializes descriptor options towards the end of the build."""
|
| 189 |
+
if self._loaded_options:
|
| 190 |
+
return
|
| 191 |
+
|
| 192 |
+
# pylint: disable=g-import-not-at-top
|
| 193 |
+
from google.protobuf import descriptor_pb2
|
| 194 |
+
|
| 195 |
+
if not hasattr(descriptor_pb2, self._options_class_name):
|
| 196 |
+
raise RuntimeError(
|
| 197 |
+
'Unknown options class name %s!' % self._options_class_name
|
| 198 |
+
)
|
| 199 |
+
options_class = getattr(descriptor_pb2, self._options_class_name)
|
| 200 |
+
features = None
|
| 201 |
+
edition = self.file._edition
|
| 202 |
+
|
| 203 |
+
if not self.has_options:
|
| 204 |
+
if not self._features:
|
| 205 |
+
features = self._ResolveFeatures(
|
| 206 |
+
descriptor_pb2.Edition.Value(edition), options_class()
|
| 207 |
+
)
|
| 208 |
+
with _lock:
|
| 209 |
+
self._loaded_options = options_class()
|
| 210 |
+
if not self._features:
|
| 211 |
+
self._features = features
|
| 212 |
+
else:
|
| 213 |
+
if not self._serialized_options:
|
| 214 |
+
options = self._options
|
| 215 |
+
else:
|
| 216 |
+
options = _ParseOptions(options_class(), self._serialized_options)
|
| 217 |
+
|
| 218 |
+
if not self._features:
|
| 219 |
+
features = self._ResolveFeatures(
|
| 220 |
+
descriptor_pb2.Edition.Value(edition), options
|
| 221 |
+
)
|
| 222 |
+
with _lock:
|
| 223 |
+
self._loaded_options = options
|
| 224 |
+
if not self._features:
|
| 225 |
+
self._features = features
|
| 226 |
+
if options.HasField('features'):
|
| 227 |
+
options.ClearField('features')
|
| 228 |
+
if not options.SerializeToString():
|
| 229 |
+
self._loaded_options = options_class()
|
| 230 |
+
self.has_options = False
|
| 231 |
+
|
| 232 |
+
def GetOptions(self):
|
| 233 |
+
"""Retrieves descriptor options.
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
The options set on this descriptor.
|
| 237 |
+
"""
|
| 238 |
+
if not self._loaded_options:
|
| 239 |
+
self._LazyLoadOptions()
|
| 240 |
+
return self._loaded_options
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class _NestedDescriptorBase(DescriptorBase):
|
| 244 |
+
"""Common class for descriptors that can be nested."""
|
| 245 |
+
|
| 246 |
+
def __init__(self, options, options_class_name, name, full_name,
|
| 247 |
+
file, containing_type, serialized_start=None,
|
| 248 |
+
serialized_end=None, serialized_options=None):
|
| 249 |
+
"""Constructor.
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
options: Protocol message options or None to use default message options.
|
| 253 |
+
options_class_name (str): The class name of the above options.
|
| 254 |
+
name (str): Name of this protocol message type.
|
| 255 |
+
full_name (str): Fully-qualified name of this protocol message type, which
|
| 256 |
+
will include protocol "package" name and the name of any enclosing
|
| 257 |
+
types.
|
| 258 |
+
containing_type: if provided, this is a nested descriptor, with this
|
| 259 |
+
descriptor as parent, otherwise None.
|
| 260 |
+
serialized_start: The start index (inclusive) in block in the
|
| 261 |
+
file.serialized_pb that describes this descriptor.
|
| 262 |
+
serialized_end: The end index (exclusive) in block in the
|
| 263 |
+
file.serialized_pb that describes this descriptor.
|
| 264 |
+
serialized_options: Protocol message serialized options or None.
|
| 265 |
+
"""
|
| 266 |
+
super(_NestedDescriptorBase, self).__init__(
|
| 267 |
+
file, options, serialized_options, options_class_name
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
self.name = name
|
| 271 |
+
# TODO: Add function to calculate full_name instead of having it in
|
| 272 |
+
# memory?
|
| 273 |
+
self.full_name = full_name
|
| 274 |
+
self.containing_type = containing_type
|
| 275 |
+
|
| 276 |
+
self._serialized_start = serialized_start
|
| 277 |
+
self._serialized_end = serialized_end
|
| 278 |
+
|
| 279 |
+
def CopyToProto(self, proto):
|
| 280 |
+
"""Copies this to the matching proto in descriptor_pb2.
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
proto: An empty proto instance from descriptor_pb2.
|
| 284 |
+
|
| 285 |
+
Raises:
|
| 286 |
+
Error: If self couldn't be serialized, due to to few constructor
|
| 287 |
+
arguments.
|
| 288 |
+
"""
|
| 289 |
+
if (self.file is not None and
|
| 290 |
+
self._serialized_start is not None and
|
| 291 |
+
self._serialized_end is not None):
|
| 292 |
+
proto.ParseFromString(self.file.serialized_pb[
|
| 293 |
+
self._serialized_start:self._serialized_end])
|
| 294 |
+
else:
|
| 295 |
+
raise Error('Descriptor does not contain serialization.')
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class Descriptor(_NestedDescriptorBase):
|
| 299 |
+
|
| 300 |
+
"""Descriptor for a protocol message type.
|
| 301 |
+
|
| 302 |
+
Attributes:
|
| 303 |
+
name (str): Name of this protocol message type.
|
| 304 |
+
full_name (str): Fully-qualified name of this protocol message type,
|
| 305 |
+
which will include protocol "package" name and the name of any
|
| 306 |
+
enclosing types.
|
| 307 |
+
containing_type (Descriptor): Reference to the descriptor of the type
|
| 308 |
+
containing us, or None if this is top-level.
|
| 309 |
+
fields (list[FieldDescriptor]): Field descriptors for all fields in
|
| 310 |
+
this type.
|
| 311 |
+
fields_by_number (dict(int, FieldDescriptor)): Same
|
| 312 |
+
:class:`FieldDescriptor` objects as in :attr:`fields`, but indexed
|
| 313 |
+
by "number" attribute in each FieldDescriptor.
|
| 314 |
+
fields_by_name (dict(str, FieldDescriptor)): Same
|
| 315 |
+
:class:`FieldDescriptor` objects as in :attr:`fields`, but indexed by
|
| 316 |
+
"name" attribute in each :class:`FieldDescriptor`.
|
| 317 |
+
nested_types (list[Descriptor]): Descriptor references
|
| 318 |
+
for all protocol message types nested within this one.
|
| 319 |
+
nested_types_by_name (dict(str, Descriptor)): Same Descriptor
|
| 320 |
+
objects as in :attr:`nested_types`, but indexed by "name" attribute
|
| 321 |
+
in each Descriptor.
|
| 322 |
+
enum_types (list[EnumDescriptor]): :class:`EnumDescriptor` references
|
| 323 |
+
for all enums contained within this type.
|
| 324 |
+
enum_types_by_name (dict(str, EnumDescriptor)): Same
|
| 325 |
+
:class:`EnumDescriptor` objects as in :attr:`enum_types`, but
|
| 326 |
+
indexed by "name" attribute in each EnumDescriptor.
|
| 327 |
+
enum_values_by_name (dict(str, EnumValueDescriptor)): Dict mapping
|
| 328 |
+
from enum value name to :class:`EnumValueDescriptor` for that value.
|
| 329 |
+
extensions (list[FieldDescriptor]): All extensions defined directly
|
| 330 |
+
within this message type (NOT within a nested type).
|
| 331 |
+
extensions_by_name (dict(str, FieldDescriptor)): Same FieldDescriptor
|
| 332 |
+
objects as :attr:`extensions`, but indexed by "name" attribute of each
|
| 333 |
+
FieldDescriptor.
|
| 334 |
+
is_extendable (bool): Does this type define any extension ranges?
|
| 335 |
+
oneofs (list[OneofDescriptor]): The list of descriptors for oneof fields
|
| 336 |
+
in this message.
|
| 337 |
+
oneofs_by_name (dict(str, OneofDescriptor)): Same objects as in
|
| 338 |
+
:attr:`oneofs`, but indexed by "name" attribute.
|
| 339 |
+
file (FileDescriptor): Reference to file descriptor.
|
| 340 |
+
is_map_entry: If the message type is a map entry.
|
| 341 |
+
|
| 342 |
+
"""
|
| 343 |
+
|
| 344 |
+
if _USE_C_DESCRIPTORS:
|
| 345 |
+
_C_DESCRIPTOR_CLASS = _message.Descriptor
|
| 346 |
+
|
| 347 |
+
def __new__(
|
| 348 |
+
cls,
|
| 349 |
+
name=None,
|
| 350 |
+
full_name=None,
|
| 351 |
+
filename=None,
|
| 352 |
+
containing_type=None,
|
| 353 |
+
fields=None,
|
| 354 |
+
nested_types=None,
|
| 355 |
+
enum_types=None,
|
| 356 |
+
extensions=None,
|
| 357 |
+
options=None,
|
| 358 |
+
serialized_options=None,
|
| 359 |
+
is_extendable=True,
|
| 360 |
+
extension_ranges=None,
|
| 361 |
+
oneofs=None,
|
| 362 |
+
file=None, # pylint: disable=redefined-builtin
|
| 363 |
+
serialized_start=None,
|
| 364 |
+
serialized_end=None,
|
| 365 |
+
syntax=None,
|
| 366 |
+
is_map_entry=False,
|
| 367 |
+
create_key=None):
|
| 368 |
+
_message.Message._CheckCalledFromGeneratedFile()
|
| 369 |
+
return _message.default_pool.FindMessageTypeByName(full_name)
|
| 370 |
+
|
| 371 |
+
# NOTE: The file argument redefining a builtin is nothing we can
|
| 372 |
+
# fix right now since we don't know how many clients already rely on the
|
| 373 |
+
# name of the argument.
|
| 374 |
+
def __init__(self, name, full_name, filename, containing_type, fields,
|
| 375 |
+
nested_types, enum_types, extensions, options=None,
|
| 376 |
+
serialized_options=None,
|
| 377 |
+
is_extendable=True, extension_ranges=None, oneofs=None,
|
| 378 |
+
file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin
|
| 379 |
+
syntax=None, is_map_entry=False, create_key=None):
|
| 380 |
+
"""Arguments to __init__() are as described in the description
|
| 381 |
+
of Descriptor fields above.
|
| 382 |
+
|
| 383 |
+
Note that filename is an obsolete argument, that is not used anymore.
|
| 384 |
+
Please use file.name to access this as an attribute.
|
| 385 |
+
"""
|
| 386 |
+
if create_key is not _internal_create_key:
|
| 387 |
+
_Deprecated('Descriptor')
|
| 388 |
+
|
| 389 |
+
super(Descriptor, self).__init__(
|
| 390 |
+
options, 'MessageOptions', name, full_name, file,
|
| 391 |
+
containing_type, serialized_start=serialized_start,
|
| 392 |
+
serialized_end=serialized_end, serialized_options=serialized_options)
|
| 393 |
+
|
| 394 |
+
# We have fields in addition to fields_by_name and fields_by_number,
|
| 395 |
+
# so that:
|
| 396 |
+
# 1. Clients can index fields by "order in which they're listed."
|
| 397 |
+
# 2. Clients can easily iterate over all fields with the terse
|
| 398 |
+
# syntax: for f in descriptor.fields: ...
|
| 399 |
+
self.fields = fields
|
| 400 |
+
for field in self.fields:
|
| 401 |
+
field.containing_type = self
|
| 402 |
+
field.file = file
|
| 403 |
+
self.fields_by_number = dict((f.number, f) for f in fields)
|
| 404 |
+
self.fields_by_name = dict((f.name, f) for f in fields)
|
| 405 |
+
self._fields_by_camelcase_name = None
|
| 406 |
+
|
| 407 |
+
self.nested_types = nested_types
|
| 408 |
+
for nested_type in nested_types:
|
| 409 |
+
nested_type.containing_type = self
|
| 410 |
+
self.nested_types_by_name = dict((t.name, t) for t in nested_types)
|
| 411 |
+
|
| 412 |
+
self.enum_types = enum_types
|
| 413 |
+
for enum_type in self.enum_types:
|
| 414 |
+
enum_type.containing_type = self
|
| 415 |
+
self.enum_types_by_name = dict((t.name, t) for t in enum_types)
|
| 416 |
+
self.enum_values_by_name = dict(
|
| 417 |
+
(v.name, v) for t in enum_types for v in t.values)
|
| 418 |
+
|
| 419 |
+
self.extensions = extensions
|
| 420 |
+
for extension in self.extensions:
|
| 421 |
+
extension.extension_scope = self
|
| 422 |
+
self.extensions_by_name = dict((f.name, f) for f in extensions)
|
| 423 |
+
self.is_extendable = is_extendable
|
| 424 |
+
self.extension_ranges = extension_ranges
|
| 425 |
+
self.oneofs = oneofs if oneofs is not None else []
|
| 426 |
+
self.oneofs_by_name = dict((o.name, o) for o in self.oneofs)
|
| 427 |
+
for oneof in self.oneofs:
|
| 428 |
+
oneof.containing_type = self
|
| 429 |
+
oneof.file = file
|
| 430 |
+
self._is_map_entry = is_map_entry
|
| 431 |
+
|
| 432 |
+
@property
|
| 433 |
+
def _parent(self):
|
| 434 |
+
return self.containing_type or self.file
|
| 435 |
+
|
| 436 |
+
@property
|
| 437 |
+
def fields_by_camelcase_name(self):
|
| 438 |
+
"""Same FieldDescriptor objects as in :attr:`fields`, but indexed by
|
| 439 |
+
:attr:`FieldDescriptor.camelcase_name`.
|
| 440 |
+
"""
|
| 441 |
+
if self._fields_by_camelcase_name is None:
|
| 442 |
+
self._fields_by_camelcase_name = dict(
|
| 443 |
+
(f.camelcase_name, f) for f in self.fields)
|
| 444 |
+
return self._fields_by_camelcase_name
|
| 445 |
+
|
| 446 |
+
def EnumValueName(self, enum, value):
|
| 447 |
+
"""Returns the string name of an enum value.
|
| 448 |
+
|
| 449 |
+
This is just a small helper method to simplify a common operation.
|
| 450 |
+
|
| 451 |
+
Args:
|
| 452 |
+
enum: string name of the Enum.
|
| 453 |
+
value: int, value of the enum.
|
| 454 |
+
|
| 455 |
+
Returns:
|
| 456 |
+
string name of the enum value.
|
| 457 |
+
|
| 458 |
+
Raises:
|
| 459 |
+
KeyError if either the Enum doesn't exist or the value is not a valid
|
| 460 |
+
value for the enum.
|
| 461 |
+
"""
|
| 462 |
+
return self.enum_types_by_name[enum].values_by_number[value].name
|
| 463 |
+
|
| 464 |
+
def CopyToProto(self, proto):
|
| 465 |
+
"""Copies this to a descriptor_pb2.DescriptorProto.
|
| 466 |
+
|
| 467 |
+
Args:
|
| 468 |
+
proto: An empty descriptor_pb2.DescriptorProto.
|
| 469 |
+
"""
|
| 470 |
+
# This function is overridden to give a better doc comment.
|
| 471 |
+
super(Descriptor, self).CopyToProto(proto)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
# TODO: We should have aggressive checking here,
|
| 475 |
+
# for example:
|
| 476 |
+
# * If you specify a repeated field, you should not be allowed
|
| 477 |
+
# to specify a default value.
|
| 478 |
+
# * [Other examples here as needed].
|
| 479 |
+
#
|
| 480 |
+
# TODO: for this and other *Descriptor classes, we
|
| 481 |
+
# might also want to lock things down aggressively (e.g.,
|
| 482 |
+
# prevent clients from setting the attributes). Having
|
| 483 |
+
# stronger invariants here in general will reduce the number
|
| 484 |
+
# of runtime checks we must do in reflection.py...
|
| 485 |
+
class FieldDescriptor(DescriptorBase):
|
| 486 |
+
|
| 487 |
+
"""Descriptor for a single field in a .proto file.
|
| 488 |
+
|
| 489 |
+
Attributes:
|
| 490 |
+
name (str): Name of this field, exactly as it appears in .proto.
|
| 491 |
+
full_name (str): Name of this field, including containing scope. This is
|
| 492 |
+
particularly relevant for extensions.
|
| 493 |
+
index (int): Dense, 0-indexed index giving the order that this
|
| 494 |
+
field textually appears within its message in the .proto file.
|
| 495 |
+
number (int): Tag number declared for this field in the .proto file.
|
| 496 |
+
|
| 497 |
+
type (int): (One of the TYPE_* constants below) Declared type.
|
| 498 |
+
cpp_type (int): (One of the CPPTYPE_* constants below) C++ type used to
|
| 499 |
+
represent this field.
|
| 500 |
+
|
| 501 |
+
label (int): (One of the LABEL_* constants below) Tells whether this
|
| 502 |
+
field is optional, required, or repeated.
|
| 503 |
+
has_default_value (bool): True if this field has a default value defined,
|
| 504 |
+
otherwise false.
|
| 505 |
+
default_value (Varies): Default value of this field. Only
|
| 506 |
+
meaningful for non-repeated scalar fields. Repeated fields
|
| 507 |
+
should always set this to [], and non-repeated composite
|
| 508 |
+
fields should always set this to None.
|
| 509 |
+
|
| 510 |
+
containing_type (Descriptor): Descriptor of the protocol message
|
| 511 |
+
type that contains this field. Set by the Descriptor constructor
|
| 512 |
+
if we're passed into one.
|
| 513 |
+
Somewhat confusingly, for extension fields, this is the
|
| 514 |
+
descriptor of the EXTENDED message, not the descriptor
|
| 515 |
+
of the message containing this field. (See is_extension and
|
| 516 |
+
extension_scope below).
|
| 517 |
+
message_type (Descriptor): If a composite field, a descriptor
|
| 518 |
+
of the message type contained in this field. Otherwise, this is None.
|
| 519 |
+
enum_type (EnumDescriptor): If this field contains an enum, a
|
| 520 |
+
descriptor of that enum. Otherwise, this is None.
|
| 521 |
+
|
| 522 |
+
is_extension: True iff this describes an extension field.
|
| 523 |
+
extension_scope (Descriptor): Only meaningful if is_extension is True.
|
| 524 |
+
Gives the message that immediately contains this extension field.
|
| 525 |
+
Will be None iff we're a top-level (file-level) extension field.
|
| 526 |
+
|
| 527 |
+
options (descriptor_pb2.FieldOptions): Protocol message field options or
|
| 528 |
+
None to use default field options.
|
| 529 |
+
|
| 530 |
+
containing_oneof (OneofDescriptor): If the field is a member of a oneof
|
| 531 |
+
union, contains its descriptor. Otherwise, None.
|
| 532 |
+
|
| 533 |
+
file (FileDescriptor): Reference to file descriptor.
|
| 534 |
+
"""
|
| 535 |
+
|
| 536 |
+
# Must be consistent with C++ FieldDescriptor::Type enum in
|
| 537 |
+
# descriptor.h.
|
| 538 |
+
#
|
| 539 |
+
# TODO: Find a way to eliminate this repetition.
|
| 540 |
+
TYPE_DOUBLE = 1
|
| 541 |
+
TYPE_FLOAT = 2
|
| 542 |
+
TYPE_INT64 = 3
|
| 543 |
+
TYPE_UINT64 = 4
|
| 544 |
+
TYPE_INT32 = 5
|
| 545 |
+
TYPE_FIXED64 = 6
|
| 546 |
+
TYPE_FIXED32 = 7
|
| 547 |
+
TYPE_BOOL = 8
|
| 548 |
+
TYPE_STRING = 9
|
| 549 |
+
TYPE_GROUP = 10
|
| 550 |
+
TYPE_MESSAGE = 11
|
| 551 |
+
TYPE_BYTES = 12
|
| 552 |
+
TYPE_UINT32 = 13
|
| 553 |
+
TYPE_ENUM = 14
|
| 554 |
+
TYPE_SFIXED32 = 15
|
| 555 |
+
TYPE_SFIXED64 = 16
|
| 556 |
+
TYPE_SINT32 = 17
|
| 557 |
+
TYPE_SINT64 = 18
|
| 558 |
+
MAX_TYPE = 18
|
| 559 |
+
|
| 560 |
+
# Must be consistent with C++ FieldDescriptor::CppType enum in
|
| 561 |
+
# descriptor.h.
|
| 562 |
+
#
|
| 563 |
+
# TODO: Find a way to eliminate this repetition.
|
| 564 |
+
CPPTYPE_INT32 = 1
|
| 565 |
+
CPPTYPE_INT64 = 2
|
| 566 |
+
CPPTYPE_UINT32 = 3
|
| 567 |
+
CPPTYPE_UINT64 = 4
|
| 568 |
+
CPPTYPE_DOUBLE = 5
|
| 569 |
+
CPPTYPE_FLOAT = 6
|
| 570 |
+
CPPTYPE_BOOL = 7
|
| 571 |
+
CPPTYPE_ENUM = 8
|
| 572 |
+
CPPTYPE_STRING = 9
|
| 573 |
+
CPPTYPE_MESSAGE = 10
|
| 574 |
+
MAX_CPPTYPE = 10
|
| 575 |
+
|
| 576 |
+
_PYTHON_TO_CPP_PROTO_TYPE_MAP = {
|
| 577 |
+
TYPE_DOUBLE: CPPTYPE_DOUBLE,
|
| 578 |
+
TYPE_FLOAT: CPPTYPE_FLOAT,
|
| 579 |
+
TYPE_ENUM: CPPTYPE_ENUM,
|
| 580 |
+
TYPE_INT64: CPPTYPE_INT64,
|
| 581 |
+
TYPE_SINT64: CPPTYPE_INT64,
|
| 582 |
+
TYPE_SFIXED64: CPPTYPE_INT64,
|
| 583 |
+
TYPE_UINT64: CPPTYPE_UINT64,
|
| 584 |
+
TYPE_FIXED64: CPPTYPE_UINT64,
|
| 585 |
+
TYPE_INT32: CPPTYPE_INT32,
|
| 586 |
+
TYPE_SFIXED32: CPPTYPE_INT32,
|
| 587 |
+
TYPE_SINT32: CPPTYPE_INT32,
|
| 588 |
+
TYPE_UINT32: CPPTYPE_UINT32,
|
| 589 |
+
TYPE_FIXED32: CPPTYPE_UINT32,
|
| 590 |
+
TYPE_BYTES: CPPTYPE_STRING,
|
| 591 |
+
TYPE_STRING: CPPTYPE_STRING,
|
| 592 |
+
TYPE_BOOL: CPPTYPE_BOOL,
|
| 593 |
+
TYPE_MESSAGE: CPPTYPE_MESSAGE,
|
| 594 |
+
TYPE_GROUP: CPPTYPE_MESSAGE
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
# Must be consistent with C++ FieldDescriptor::Label enum in
|
| 598 |
+
# descriptor.h.
|
| 599 |
+
#
|
| 600 |
+
# TODO: Find a way to eliminate this repetition.
|
| 601 |
+
LABEL_OPTIONAL = 1
|
| 602 |
+
LABEL_REQUIRED = 2
|
| 603 |
+
LABEL_REPEATED = 3
|
| 604 |
+
MAX_LABEL = 3
|
| 605 |
+
|
| 606 |
+
# Must be consistent with C++ constants kMaxNumber, kFirstReservedNumber,
|
| 607 |
+
# and kLastReservedNumber in descriptor.h
|
| 608 |
+
MAX_FIELD_NUMBER = (1 << 29) - 1
|
| 609 |
+
FIRST_RESERVED_FIELD_NUMBER = 19000
|
| 610 |
+
LAST_RESERVED_FIELD_NUMBER = 19999
|
| 611 |
+
|
| 612 |
+
if _USE_C_DESCRIPTORS:
|
| 613 |
+
_C_DESCRIPTOR_CLASS = _message.FieldDescriptor
|
| 614 |
+
|
| 615 |
+
def __new__(cls, name, full_name, index, number, type, cpp_type, label,
|
| 616 |
+
default_value, message_type, enum_type, containing_type,
|
| 617 |
+
is_extension, extension_scope, options=None,
|
| 618 |
+
serialized_options=None,
|
| 619 |
+
has_default_value=True, containing_oneof=None, json_name=None,
|
| 620 |
+
file=None, create_key=None): # pylint: disable=redefined-builtin
|
| 621 |
+
_message.Message._CheckCalledFromGeneratedFile()
|
| 622 |
+
if is_extension:
|
| 623 |
+
return _message.default_pool.FindExtensionByName(full_name)
|
| 624 |
+
else:
|
| 625 |
+
return _message.default_pool.FindFieldByName(full_name)
|
| 626 |
+
|
| 627 |
+
def __init__(self, name, full_name, index, number, type, cpp_type, label,
|
| 628 |
+
default_value, message_type, enum_type, containing_type,
|
| 629 |
+
is_extension, extension_scope, options=None,
|
| 630 |
+
serialized_options=None,
|
| 631 |
+
has_default_value=True, containing_oneof=None, json_name=None,
|
| 632 |
+
file=None, create_key=None): # pylint: disable=redefined-builtin
|
| 633 |
+
"""The arguments are as described in the description of FieldDescriptor
|
| 634 |
+
attributes above.
|
| 635 |
+
|
| 636 |
+
Note that containing_type may be None, and may be set later if necessary
|
| 637 |
+
(to deal with circular references between message types, for example).
|
| 638 |
+
Likewise for extension_scope.
|
| 639 |
+
"""
|
| 640 |
+
if create_key is not _internal_create_key:
|
| 641 |
+
_Deprecated('FieldDescriptor')
|
| 642 |
+
|
| 643 |
+
super(FieldDescriptor, self).__init__(
|
| 644 |
+
file, options, serialized_options, 'FieldOptions'
|
| 645 |
+
)
|
| 646 |
+
self.name = name
|
| 647 |
+
self.full_name = full_name
|
| 648 |
+
self._camelcase_name = None
|
| 649 |
+
if json_name is None:
|
| 650 |
+
self.json_name = _ToJsonName(name)
|
| 651 |
+
else:
|
| 652 |
+
self.json_name = json_name
|
| 653 |
+
self.index = index
|
| 654 |
+
self.number = number
|
| 655 |
+
self._type = type
|
| 656 |
+
self.cpp_type = cpp_type
|
| 657 |
+
self._label = label
|
| 658 |
+
self.has_default_value = has_default_value
|
| 659 |
+
self.default_value = default_value
|
| 660 |
+
self.containing_type = containing_type
|
| 661 |
+
self.message_type = message_type
|
| 662 |
+
self.enum_type = enum_type
|
| 663 |
+
self.is_extension = is_extension
|
| 664 |
+
self.extension_scope = extension_scope
|
| 665 |
+
self.containing_oneof = containing_oneof
|
| 666 |
+
if api_implementation.Type() == 'python':
|
| 667 |
+
self._cdescriptor = None
|
| 668 |
+
else:
|
| 669 |
+
if is_extension:
|
| 670 |
+
self._cdescriptor = _message.default_pool.FindExtensionByName(full_name)
|
| 671 |
+
else:
|
| 672 |
+
self._cdescriptor = _message.default_pool.FindFieldByName(full_name)
|
| 673 |
+
|
| 674 |
+
@property
|
| 675 |
+
def _parent(self):
|
| 676 |
+
if self.containing_oneof:
|
| 677 |
+
return self.containing_oneof
|
| 678 |
+
if self.is_extension:
|
| 679 |
+
return self.extension_scope or self.file
|
| 680 |
+
return self.containing_type
|
| 681 |
+
|
| 682 |
+
def _InferLegacyFeatures(self, edition, options, features):
|
| 683 |
+
# pylint: disable=g-import-not-at-top
|
| 684 |
+
from google.protobuf import descriptor_pb2
|
| 685 |
+
|
| 686 |
+
if edition >= descriptor_pb2.Edition.EDITION_2023:
|
| 687 |
+
return
|
| 688 |
+
|
| 689 |
+
if self._label == FieldDescriptor.LABEL_REQUIRED:
|
| 690 |
+
features.field_presence = (
|
| 691 |
+
descriptor_pb2.FeatureSet.FieldPresence.LEGACY_REQUIRED
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
if self._type == FieldDescriptor.TYPE_GROUP:
|
| 695 |
+
features.message_encoding = (
|
| 696 |
+
descriptor_pb2.FeatureSet.MessageEncoding.DELIMITED
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
if options.HasField('packed'):
|
| 700 |
+
features.repeated_field_encoding = (
|
| 701 |
+
descriptor_pb2.FeatureSet.RepeatedFieldEncoding.PACKED
|
| 702 |
+
if options.packed
|
| 703 |
+
else descriptor_pb2.FeatureSet.RepeatedFieldEncoding.EXPANDED
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
@property
|
| 707 |
+
def type(self):
|
| 708 |
+
if (
|
| 709 |
+
self._GetFeatures().message_encoding
|
| 710 |
+
== _FEATURESET_MESSAGE_ENCODING_DELIMITED
|
| 711 |
+
and self.message_type
|
| 712 |
+
and not self.message_type.GetOptions().map_entry
|
| 713 |
+
and not self.containing_type.GetOptions().map_entry
|
| 714 |
+
):
|
| 715 |
+
return FieldDescriptor.TYPE_GROUP
|
| 716 |
+
return self._type
|
| 717 |
+
|
| 718 |
+
@type.setter
|
| 719 |
+
def type(self, val):
|
| 720 |
+
self._type = val
|
| 721 |
+
|
| 722 |
+
@property
|
| 723 |
+
def label(self):
|
| 724 |
+
if (
|
| 725 |
+
self._GetFeatures().field_presence
|
| 726 |
+
== _FEATURESET_FIELD_PRESENCE_LEGACY_REQUIRED
|
| 727 |
+
):
|
| 728 |
+
return FieldDescriptor.LABEL_REQUIRED
|
| 729 |
+
return self._label
|
| 730 |
+
|
| 731 |
+
@property
|
| 732 |
+
def camelcase_name(self):
|
| 733 |
+
"""Camelcase name of this field.
|
| 734 |
+
|
| 735 |
+
Returns:
|
| 736 |
+
str: the name in CamelCase.
|
| 737 |
+
"""
|
| 738 |
+
if self._camelcase_name is None:
|
| 739 |
+
self._camelcase_name = _ToCamelCase(self.name)
|
| 740 |
+
return self._camelcase_name
|
| 741 |
+
|
| 742 |
+
@property
|
| 743 |
+
def has_presence(self):
|
| 744 |
+
"""Whether the field distinguishes between unpopulated and default values.
|
| 745 |
+
|
| 746 |
+
Raises:
|
| 747 |
+
RuntimeError: singular field that is not linked with message nor file.
|
| 748 |
+
"""
|
| 749 |
+
if self.label == FieldDescriptor.LABEL_REPEATED:
|
| 750 |
+
return False
|
| 751 |
+
if (
|
| 752 |
+
self.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE
|
| 753 |
+
or self.is_extension
|
| 754 |
+
or self.containing_oneof
|
| 755 |
+
):
|
| 756 |
+
return True
|
| 757 |
+
|
| 758 |
+
return (
|
| 759 |
+
self._GetFeatures().field_presence
|
| 760 |
+
!= _FEATURESET_FIELD_PRESENCE_IMPLICIT
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
@property
|
| 764 |
+
def is_packed(self):
|
| 765 |
+
"""Returns if the field is packed."""
|
| 766 |
+
if self.label != FieldDescriptor.LABEL_REPEATED:
|
| 767 |
+
return False
|
| 768 |
+
field_type = self.type
|
| 769 |
+
if (field_type == FieldDescriptor.TYPE_STRING or
|
| 770 |
+
field_type == FieldDescriptor.TYPE_GROUP or
|
| 771 |
+
field_type == FieldDescriptor.TYPE_MESSAGE or
|
| 772 |
+
field_type == FieldDescriptor.TYPE_BYTES):
|
| 773 |
+
return False
|
| 774 |
+
|
| 775 |
+
return (
|
| 776 |
+
self._GetFeatures().repeated_field_encoding
|
| 777 |
+
== _FEATURESET_REPEATED_FIELD_ENCODING_PACKED
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
@staticmethod
|
| 781 |
+
def ProtoTypeToCppProtoType(proto_type):
|
| 782 |
+
"""Converts from a Python proto type to a C++ Proto Type.
|
| 783 |
+
|
| 784 |
+
The Python ProtocolBuffer classes specify both the 'Python' datatype and the
|
| 785 |
+
'C++' datatype - and they're not the same. This helper method should
|
| 786 |
+
translate from one to another.
|
| 787 |
+
|
| 788 |
+
Args:
|
| 789 |
+
proto_type: the Python proto type (descriptor.FieldDescriptor.TYPE_*)
|
| 790 |
+
Returns:
|
| 791 |
+
int: descriptor.FieldDescriptor.CPPTYPE_*, the C++ type.
|
| 792 |
+
Raises:
|
| 793 |
+
TypeTransformationError: when the Python proto type isn't known.
|
| 794 |
+
"""
|
| 795 |
+
try:
|
| 796 |
+
return FieldDescriptor._PYTHON_TO_CPP_PROTO_TYPE_MAP[proto_type]
|
| 797 |
+
except KeyError:
|
| 798 |
+
raise TypeTransformationError('Unknown proto_type: %s' % proto_type)
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
class EnumDescriptor(_NestedDescriptorBase):
|
| 802 |
+
|
| 803 |
+
"""Descriptor for an enum defined in a .proto file.
|
| 804 |
+
|
| 805 |
+
Attributes:
|
| 806 |
+
name (str): Name of the enum type.
|
| 807 |
+
full_name (str): Full name of the type, including package name
|
| 808 |
+
and any enclosing type(s).
|
| 809 |
+
|
| 810 |
+
values (list[EnumValueDescriptor]): List of the values
|
| 811 |
+
in this enum.
|
| 812 |
+
values_by_name (dict(str, EnumValueDescriptor)): Same as :attr:`values`,
|
| 813 |
+
but indexed by the "name" field of each EnumValueDescriptor.
|
| 814 |
+
values_by_number (dict(int, EnumValueDescriptor)): Same as :attr:`values`,
|
| 815 |
+
but indexed by the "number" field of each EnumValueDescriptor.
|
| 816 |
+
containing_type (Descriptor): Descriptor of the immediate containing
|
| 817 |
+
type of this enum, or None if this is an enum defined at the
|
| 818 |
+
top level in a .proto file. Set by Descriptor's constructor
|
| 819 |
+
if we're passed into one.
|
| 820 |
+
file (FileDescriptor): Reference to file descriptor.
|
| 821 |
+
options (descriptor_pb2.EnumOptions): Enum options message or
|
| 822 |
+
None to use default enum options.
|
| 823 |
+
"""
|
| 824 |
+
|
| 825 |
+
if _USE_C_DESCRIPTORS:
|
| 826 |
+
_C_DESCRIPTOR_CLASS = _message.EnumDescriptor
|
| 827 |
+
|
| 828 |
+
def __new__(cls, name, full_name, filename, values,
|
| 829 |
+
containing_type=None, options=None,
|
| 830 |
+
serialized_options=None, file=None, # pylint: disable=redefined-builtin
|
| 831 |
+
serialized_start=None, serialized_end=None, create_key=None):
|
| 832 |
+
_message.Message._CheckCalledFromGeneratedFile()
|
| 833 |
+
return _message.default_pool.FindEnumTypeByName(full_name)
|
| 834 |
+
|
| 835 |
+
def __init__(self, name, full_name, filename, values,
|
| 836 |
+
containing_type=None, options=None,
|
| 837 |
+
serialized_options=None, file=None, # pylint: disable=redefined-builtin
|
| 838 |
+
serialized_start=None, serialized_end=None, create_key=None):
|
| 839 |
+
"""Arguments are as described in the attribute description above.
|
| 840 |
+
|
| 841 |
+
Note that filename is an obsolete argument, that is not used anymore.
|
| 842 |
+
Please use file.name to access this as an attribute.
|
| 843 |
+
"""
|
| 844 |
+
if create_key is not _internal_create_key:
|
| 845 |
+
_Deprecated('EnumDescriptor')
|
| 846 |
+
|
| 847 |
+
super(EnumDescriptor, self).__init__(
|
| 848 |
+
options, 'EnumOptions', name, full_name, file,
|
| 849 |
+
containing_type, serialized_start=serialized_start,
|
| 850 |
+
serialized_end=serialized_end, serialized_options=serialized_options)
|
| 851 |
+
|
| 852 |
+
self.values = values
|
| 853 |
+
for value in self.values:
|
| 854 |
+
value.file = file
|
| 855 |
+
value.type = self
|
| 856 |
+
self.values_by_name = dict((v.name, v) for v in values)
|
| 857 |
+
# Values are reversed to ensure that the first alias is retained.
|
| 858 |
+
self.values_by_number = dict((v.number, v) for v in reversed(values))
|
| 859 |
+
|
| 860 |
+
@property
|
| 861 |
+
def _parent(self):
|
| 862 |
+
return self.containing_type or self.file
|
| 863 |
+
|
| 864 |
+
@property
|
| 865 |
+
def is_closed(self):
|
| 866 |
+
"""Returns true whether this is a "closed" enum.
|
| 867 |
+
|
| 868 |
+
This means that it:
|
| 869 |
+
- Has a fixed set of values, rather than being equivalent to an int32.
|
| 870 |
+
- Encountering values not in this set causes them to be treated as unknown
|
| 871 |
+
fields.
|
| 872 |
+
- The first value (i.e., the default) may be nonzero.
|
| 873 |
+
|
| 874 |
+
WARNING: Some runtimes currently have a quirk where non-closed enums are
|
| 875 |
+
treated as closed when used as the type of fields defined in a
|
| 876 |
+
`syntax = proto2;` file. This quirk is not present in all runtimes; as of
|
| 877 |
+
writing, we know that:
|
| 878 |
+
|
| 879 |
+
- C++, Java, and C++-based Python share this quirk.
|
| 880 |
+
- UPB and UPB-based Python do not.
|
| 881 |
+
- PHP and Ruby treat all enums as open regardless of declaration.
|
| 882 |
+
|
| 883 |
+
Care should be taken when using this function to respect the target
|
| 884 |
+
runtime's enum handling quirks.
|
| 885 |
+
"""
|
| 886 |
+
return self._GetFeatures().enum_type == _FEATURESET_ENUM_TYPE_CLOSED
|
| 887 |
+
|
| 888 |
+
def CopyToProto(self, proto):
|
| 889 |
+
"""Copies this to a descriptor_pb2.EnumDescriptorProto.
|
| 890 |
+
|
| 891 |
+
Args:
|
| 892 |
+
proto (descriptor_pb2.EnumDescriptorProto): An empty descriptor proto.
|
| 893 |
+
"""
|
| 894 |
+
# This function is overridden to give a better doc comment.
|
| 895 |
+
super(EnumDescriptor, self).CopyToProto(proto)
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
class EnumValueDescriptor(DescriptorBase):
|
| 899 |
+
|
| 900 |
+
"""Descriptor for a single value within an enum.
|
| 901 |
+
|
| 902 |
+
Attributes:
|
| 903 |
+
name (str): Name of this value.
|
| 904 |
+
index (int): Dense, 0-indexed index giving the order that this
|
| 905 |
+
value appears textually within its enum in the .proto file.
|
| 906 |
+
number (int): Actual number assigned to this enum value.
|
| 907 |
+
type (EnumDescriptor): :class:`EnumDescriptor` to which this value
|
| 908 |
+
belongs. Set by :class:`EnumDescriptor`'s constructor if we're
|
| 909 |
+
passed into one.
|
| 910 |
+
options (descriptor_pb2.EnumValueOptions): Enum value options message or
|
| 911 |
+
None to use default enum value options options.
|
| 912 |
+
"""
|
| 913 |
+
|
| 914 |
+
if _USE_C_DESCRIPTORS:
|
| 915 |
+
_C_DESCRIPTOR_CLASS = _message.EnumValueDescriptor
|
| 916 |
+
|
| 917 |
+
def __new__(cls, name, index, number,
|
| 918 |
+
type=None, # pylint: disable=redefined-builtin
|
| 919 |
+
options=None, serialized_options=None, create_key=None):
|
| 920 |
+
_message.Message._CheckCalledFromGeneratedFile()
|
| 921 |
+
# There is no way we can build a complete EnumValueDescriptor with the
|
| 922 |
+
# given parameters (the name of the Enum is not known, for example).
|
| 923 |
+
# Fortunately generated files just pass it to the EnumDescriptor()
|
| 924 |
+
# constructor, which will ignore it, so returning None is good enough.
|
| 925 |
+
return None
|
| 926 |
+
|
| 927 |
+
def __init__(self, name, index, number,
|
| 928 |
+
type=None, # pylint: disable=redefined-builtin
|
| 929 |
+
options=None, serialized_options=None, create_key=None):
|
| 930 |
+
"""Arguments are as described in the attribute description above."""
|
| 931 |
+
if create_key is not _internal_create_key:
|
| 932 |
+
_Deprecated('EnumValueDescriptor')
|
| 933 |
+
|
| 934 |
+
super(EnumValueDescriptor, self).__init__(
|
| 935 |
+
type.file if type else None,
|
| 936 |
+
options,
|
| 937 |
+
serialized_options,
|
| 938 |
+
'EnumValueOptions',
|
| 939 |
+
)
|
| 940 |
+
self.name = name
|
| 941 |
+
self.index = index
|
| 942 |
+
self.number = number
|
| 943 |
+
self.type = type
|
| 944 |
+
|
| 945 |
+
@property
|
| 946 |
+
def _parent(self):
|
| 947 |
+
return self.type
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
class OneofDescriptor(DescriptorBase):
|
| 951 |
+
"""Descriptor for a oneof field.
|
| 952 |
+
|
| 953 |
+
Attributes:
|
| 954 |
+
name (str): Name of the oneof field.
|
| 955 |
+
full_name (str): Full name of the oneof field, including package name.
|
| 956 |
+
index (int): 0-based index giving the order of the oneof field inside
|
| 957 |
+
its containing type.
|
| 958 |
+
containing_type (Descriptor): :class:`Descriptor` of the protocol message
|
| 959 |
+
type that contains this field. Set by the :class:`Descriptor` constructor
|
| 960 |
+
if we're passed into one.
|
| 961 |
+
fields (list[FieldDescriptor]): The list of field descriptors this
|
| 962 |
+
oneof can contain.
|
| 963 |
+
"""
|
| 964 |
+
|
| 965 |
+
if _USE_C_DESCRIPTORS:
|
| 966 |
+
_C_DESCRIPTOR_CLASS = _message.OneofDescriptor
|
| 967 |
+
|
| 968 |
+
def __new__(
|
| 969 |
+
cls, name, full_name, index, containing_type, fields, options=None,
|
| 970 |
+
serialized_options=None, create_key=None):
|
| 971 |
+
_message.Message._CheckCalledFromGeneratedFile()
|
| 972 |
+
return _message.default_pool.FindOneofByName(full_name)
|
| 973 |
+
|
| 974 |
+
def __init__(
|
| 975 |
+
self, name, full_name, index, containing_type, fields, options=None,
|
| 976 |
+
serialized_options=None, create_key=None):
|
| 977 |
+
"""Arguments are as described in the attribute description above."""
|
| 978 |
+
if create_key is not _internal_create_key:
|
| 979 |
+
_Deprecated('OneofDescriptor')
|
| 980 |
+
|
| 981 |
+
super(OneofDescriptor, self).__init__(
|
| 982 |
+
containing_type.file if containing_type else None,
|
| 983 |
+
options,
|
| 984 |
+
serialized_options,
|
| 985 |
+
'OneofOptions',
|
| 986 |
+
)
|
| 987 |
+
self.name = name
|
| 988 |
+
self.full_name = full_name
|
| 989 |
+
self.index = index
|
| 990 |
+
self.containing_type = containing_type
|
| 991 |
+
self.fields = fields
|
| 992 |
+
|
| 993 |
+
@property
|
| 994 |
+
def _parent(self):
|
| 995 |
+
return self.containing_type
|
| 996 |
+
|
| 997 |
+
|
| 998 |
+
class ServiceDescriptor(_NestedDescriptorBase):
|
| 999 |
+
|
| 1000 |
+
"""Descriptor for a service.
|
| 1001 |
+
|
| 1002 |
+
Attributes:
|
| 1003 |
+
name (str): Name of the service.
|
| 1004 |
+
full_name (str): Full name of the service, including package name.
|
| 1005 |
+
index (int): 0-indexed index giving the order that this services
|
| 1006 |
+
definition appears within the .proto file.
|
| 1007 |
+
methods (list[MethodDescriptor]): List of methods provided by this
|
| 1008 |
+
service.
|
| 1009 |
+
methods_by_name (dict(str, MethodDescriptor)): Same
|
| 1010 |
+
:class:`MethodDescriptor` objects as in :attr:`methods_by_name`, but
|
| 1011 |
+
indexed by "name" attribute in each :class:`MethodDescriptor`.
|
| 1012 |
+
options (descriptor_pb2.ServiceOptions): Service options message or
|
| 1013 |
+
None to use default service options.
|
| 1014 |
+
file (FileDescriptor): Reference to file info.
|
| 1015 |
+
"""
|
| 1016 |
+
|
| 1017 |
+
if _USE_C_DESCRIPTORS:
|
| 1018 |
+
_C_DESCRIPTOR_CLASS = _message.ServiceDescriptor
|
| 1019 |
+
|
| 1020 |
+
def __new__(
|
| 1021 |
+
cls,
|
| 1022 |
+
name=None,
|
| 1023 |
+
full_name=None,
|
| 1024 |
+
index=None,
|
| 1025 |
+
methods=None,
|
| 1026 |
+
options=None,
|
| 1027 |
+
serialized_options=None,
|
| 1028 |
+
file=None, # pylint: disable=redefined-builtin
|
| 1029 |
+
serialized_start=None,
|
| 1030 |
+
serialized_end=None,
|
| 1031 |
+
create_key=None):
|
| 1032 |
+
_message.Message._CheckCalledFromGeneratedFile() # pylint: disable=protected-access
|
| 1033 |
+
return _message.default_pool.FindServiceByName(full_name)
|
| 1034 |
+
|
| 1035 |
+
def __init__(self, name, full_name, index, methods, options=None,
|
| 1036 |
+
serialized_options=None, file=None, # pylint: disable=redefined-builtin
|
| 1037 |
+
serialized_start=None, serialized_end=None, create_key=None):
|
| 1038 |
+
if create_key is not _internal_create_key:
|
| 1039 |
+
_Deprecated('ServiceDescriptor')
|
| 1040 |
+
|
| 1041 |
+
super(ServiceDescriptor, self).__init__(
|
| 1042 |
+
options, 'ServiceOptions', name, full_name, file,
|
| 1043 |
+
None, serialized_start=serialized_start,
|
| 1044 |
+
serialized_end=serialized_end, serialized_options=serialized_options)
|
| 1045 |
+
self.index = index
|
| 1046 |
+
self.methods = methods
|
| 1047 |
+
self.methods_by_name = dict((m.name, m) for m in methods)
|
| 1048 |
+
# Set the containing service for each method in this service.
|
| 1049 |
+
for method in self.methods:
|
| 1050 |
+
method.file = self.file
|
| 1051 |
+
method.containing_service = self
|
| 1052 |
+
|
| 1053 |
+
@property
|
| 1054 |
+
def _parent(self):
|
| 1055 |
+
return self.file
|
| 1056 |
+
|
| 1057 |
+
def FindMethodByName(self, name):
|
| 1058 |
+
"""Searches for the specified method, and returns its descriptor.
|
| 1059 |
+
|
| 1060 |
+
Args:
|
| 1061 |
+
name (str): Name of the method.
|
| 1062 |
+
|
| 1063 |
+
Returns:
|
| 1064 |
+
MethodDescriptor: The descriptor for the requested method.
|
| 1065 |
+
|
| 1066 |
+
Raises:
|
| 1067 |
+
KeyError: if the method cannot be found in the service.
|
| 1068 |
+
"""
|
| 1069 |
+
return self.methods_by_name[name]
|
| 1070 |
+
|
| 1071 |
+
def CopyToProto(self, proto):
|
| 1072 |
+
"""Copies this to a descriptor_pb2.ServiceDescriptorProto.
|
| 1073 |
+
|
| 1074 |
+
Args:
|
| 1075 |
+
proto (descriptor_pb2.ServiceDescriptorProto): An empty descriptor proto.
|
| 1076 |
+
"""
|
| 1077 |
+
# This function is overridden to give a better doc comment.
|
| 1078 |
+
super(ServiceDescriptor, self).CopyToProto(proto)
|
| 1079 |
+
|
| 1080 |
+
|
| 1081 |
+
class MethodDescriptor(DescriptorBase):

  """Descriptor for a method in a service.

  Attributes:
    name (str): Name of the method within the service.
    full_name (str): Full name of method.
    index (int): 0-indexed index of the method inside the service.
    containing_service (ServiceDescriptor): The service that contains this
      method.
    input_type (Descriptor): The descriptor of the message that this method
      accepts.
    output_type (Descriptor): The descriptor of the message that this method
      returns.
    client_streaming (bool): Whether this method uses client streaming.
    server_streaming (bool): Whether this method uses server streaming.
    options (descriptor_pb2.MethodOptions or None): Method options message, or
      None to use default method options.
  """

  if _USE_C_DESCRIPTORS:
    _C_DESCRIPTOR_CLASS = _message.MethodDescriptor

    def __new__(cls,
                name,
                full_name,
                index,
                containing_service,
                input_type,
                output_type,
                client_streaming=False,
                server_streaming=False,
                options=None,
                serialized_options=None,
                create_key=None):
      # With the C++ descriptor implementation, no Python instance is built:
      # construction is only legal from generated code, and the already
      # registered descriptor is looked up in the C++ default pool by name.
      _message.Message._CheckCalledFromGeneratedFile()  # pylint: disable=protected-access
      return _message.default_pool.FindMethodByName(full_name)

  def __init__(self,
               name,
               full_name,
               index,
               containing_service,
               input_type,
               output_type,
               client_streaming=False,
               server_streaming=False,
               options=None,
               serialized_options=None,
               create_key=None):
    """The arguments are as described in the description of MethodDescriptor
    attributes above.

    Note that containing_service may be None, and may be set later if necessary.
    """
    if create_key is not _internal_create_key:
      # Direct construction outside generated code is deprecated.
      _Deprecated('MethodDescriptor')

    # The method's file is its service's file, when a service is known.
    super(MethodDescriptor, self).__init__(
        containing_service.file if containing_service else None,
        options,
        serialized_options,
        'MethodOptions',
    )
    self.name = name
    self.full_name = full_name
    self.index = index
    self.containing_service = containing_service
    self.input_type = input_type
    self.output_type = output_type
    self.client_streaming = client_streaming
    self.server_streaming = server_streaming

  @property
  def _parent(self):
    # The enclosing scope of a method is its service.
    return self.containing_service

  def CopyToProto(self, proto):
    """Copies this to a descriptor_pb2.MethodDescriptorProto.

    Args:
      proto (descriptor_pb2.MethodDescriptorProto): An empty descriptor proto.

    Raises:
      Error: If self couldn't be serialized, due to too few constructor
        arguments.
    """
    if self.containing_service is not None:
      # Serialize via the containing service, then pull out this method's
      # slot by index.
      from google.protobuf import descriptor_pb2
      service_proto = descriptor_pb2.ServiceDescriptorProto()
      self.containing_service.CopyToProto(service_proto)
      proto.CopyFrom(service_proto.method[self.index])
    else:
      raise Error('Descriptor does not contain a service.')
|
| 1175 |
+
|
| 1176 |
+
|
| 1177 |
+
class FileDescriptor(DescriptorBase):
  """Descriptor for a file. Mimics the descriptor_pb2.FileDescriptorProto.

  Note that :attr:`enum_types_by_name`, :attr:`extensions_by_name`, and
  :attr:`dependencies` fields are only set by the
  :py:mod:`google.protobuf.message_factory` module, and not by the generated
  proto code.

  Attributes:
    name (str): Name of file, relative to root of source tree.
    package (str): Name of the package
    edition (Edition): Enum value indicating edition of the file
    serialized_pb (bytes): Byte string of serialized
      :class:`descriptor_pb2.FileDescriptorProto`.
    dependencies (list[FileDescriptor]): List of other :class:`FileDescriptor`
      objects this :class:`FileDescriptor` depends on.
    public_dependencies (list[FileDescriptor]): A subset of
      :attr:`dependencies`, which were declared as "public".
    message_types_by_name (dict(str, Descriptor)): Mapping from message names to
      their :class:`Descriptor`.
    enum_types_by_name (dict(str, EnumDescriptor)): Mapping from enum names to
      their :class:`EnumDescriptor`.
    extensions_by_name (dict(str, FieldDescriptor)): Mapping from extension
      names declared at file scope to their :class:`FieldDescriptor`.
    services_by_name (dict(str, ServiceDescriptor)): Mapping from services'
      names to their :class:`ServiceDescriptor`.
    pool (DescriptorPool): The pool this descriptor belongs to. When not passed
      to the constructor, the global default pool is used.
  """

  if _USE_C_DESCRIPTORS:
    _C_DESCRIPTOR_CLASS = _message.FileDescriptor

    def __new__(
        cls,
        name,
        package,
        options=None,
        serialized_options=None,
        serialized_pb=None,
        dependencies=None,
        public_dependencies=None,
        syntax=None,
        edition=None,
        pool=None,
        create_key=None,
    ):
      # FileDescriptor() is called from various places, not only from generated
      # files, to register dynamic proto files and messages.
      # pylint: disable=g-explicit-bool-comparison
      if serialized_pb:
        # Non-empty serialized form: register with the C++ default pool and
        # return the resulting C descriptor instead of a Python instance.
        return _message.default_pool.AddSerializedFile(serialized_pb)
      else:
        return super(FileDescriptor, cls).__new__(cls)

  def __init__(
      self,
      name,
      package,
      options=None,
      serialized_options=None,
      serialized_pb=None,
      dependencies=None,
      public_dependencies=None,
      syntax=None,
      edition=None,
      pool=None,
      create_key=None,
  ):
    """Constructor."""
    if create_key is not _internal_create_key:
      # Direct construction outside generated code is deprecated.
      _Deprecated('FileDescriptor')

    # A file is its own "file" scope, so it passes itself as the file argument.
    super(FileDescriptor, self).__init__(
        self, options, serialized_options, 'FileOptions'
    )

    # Map the legacy `syntax` value onto the newer edition model when no
    # explicit, known edition was supplied.
    if edition and edition != 'EDITION_UNKNOWN':
      self._edition = edition
    elif syntax == 'proto3':
      self._edition = 'EDITION_PROTO3'
    else:
      self._edition = 'EDITION_PROTO2'

    if pool is None:
      # Imported lazily to avoid a circular import with descriptor_pool.
      from google.protobuf import descriptor_pool
      pool = descriptor_pool.Default()
    self.pool = pool
    self.message_types_by_name = {}
    self.name = name
    self.package = package
    self.serialized_pb = serialized_pb

    # These containers are populated by message_factory, not generated code
    # (see class docstring).
    self.enum_types_by_name = {}
    self.extensions_by_name = {}
    self.services_by_name = {}
    self.dependencies = (dependencies or [])
    self.public_dependencies = (public_dependencies or [])

  def CopyToProto(self, proto):
    """Copies this to a descriptor_pb2.FileDescriptorProto.

    Args:
      proto: An empty descriptor_pb2.FileDescriptorProto.
    """
    # The serialized form captured at construction time is authoritative.
    proto.ParseFromString(self.serialized_pb)

  @property
  def _parent(self):
    # A file is a top-level scope; it has no parent.
    return None
|
| 1287 |
+
|
| 1288 |
+
|
| 1289 |
+
def _ParseOptions(message, string):
  """Parses serialized options.

  This helper function is used to parse serialized options in generated
  proto2 files. It must not be used outside proto2.

  Args:
    message: An empty options message instance to populate.
    string: The serialized options bytes.

  Returns:
    The same `message` object, populated in place.
  """
  message.ParseFromString(string)
  return message
|
| 1297 |
+
|
| 1298 |
+
|
| 1299 |
+
def _ToCamelCase(name):
|
| 1300 |
+
"""Converts name to camel-case and returns it."""
|
| 1301 |
+
capitalize_next = False
|
| 1302 |
+
result = []
|
| 1303 |
+
|
| 1304 |
+
for c in name:
|
| 1305 |
+
if c == '_':
|
| 1306 |
+
if result:
|
| 1307 |
+
capitalize_next = True
|
| 1308 |
+
elif capitalize_next:
|
| 1309 |
+
result.append(c.upper())
|
| 1310 |
+
capitalize_next = False
|
| 1311 |
+
else:
|
| 1312 |
+
result += c
|
| 1313 |
+
|
| 1314 |
+
# Lower-case the first letter.
|
| 1315 |
+
if result and result[0].isupper():
|
| 1316 |
+
result[0] = result[0].lower()
|
| 1317 |
+
return ''.join(result)
|
| 1318 |
+
|
| 1319 |
+
|
| 1320 |
+
def _OptionsOrNone(descriptor_proto):
|
| 1321 |
+
"""Returns the value of the field `options`, or None if it is not set."""
|
| 1322 |
+
if descriptor_proto.HasField('options'):
|
| 1323 |
+
return descriptor_proto.options
|
| 1324 |
+
else:
|
| 1325 |
+
return None
|
| 1326 |
+
|
| 1327 |
+
|
| 1328 |
+
def _ToJsonName(name):
|
| 1329 |
+
"""Converts name to Json name and returns it."""
|
| 1330 |
+
capitalize_next = False
|
| 1331 |
+
result = []
|
| 1332 |
+
|
| 1333 |
+
for c in name:
|
| 1334 |
+
if c == '_':
|
| 1335 |
+
capitalize_next = True
|
| 1336 |
+
elif capitalize_next:
|
| 1337 |
+
result.append(c.upper())
|
| 1338 |
+
capitalize_next = False
|
| 1339 |
+
else:
|
| 1340 |
+
result += c
|
| 1341 |
+
|
| 1342 |
+
return ''.join(result)
|
| 1343 |
+
|
| 1344 |
+
|
| 1345 |
+
def MakeDescriptor(
    desc_proto,
    package='',
    build_file_if_cpp=True,
    syntax=None,
    edition=None,
    file_desc=None,
):
  """Make a protobuf Descriptor given a DescriptorProto protobuf.

  Handles nested descriptors. Note that this is limited to the scope of defining
  a message inside of another message. Composite fields can currently only be
  resolved if the message is defined in the same scope as the field.

  Args:
    desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
    package: Optional package name for the new message Descriptor (string).
    build_file_if_cpp: Update the C++ descriptor pool if api matches. Set to
      False on recursion, so no duplicates are created.
    syntax: The syntax/semantics that should be used. Set to "proto3" to get
      proto3 field presence semantics.
    edition: The edition that should be used if syntax is "edition".
    file_desc: A FileDescriptor to place this descriptor into.

  Returns:
    A Descriptor for protobuf messages.
  """
  # pylint: disable=g-import-not-at-top
  from google.protobuf import descriptor_pb2

  # Generate a random name for this proto file to prevent conflicts with any
  # imported ones. We need to specify a file name so the descriptor pool
  # accepts our FileDescriptorProto, but it is not important what that file
  # name is actually set to.
  proto_name = binascii.hexlify(os.urandom(16)).decode('ascii')

  if package:
    file_name = os.path.join(package.replace('.', '/'), proto_name + '.proto')
  else:
    file_name = proto_name + '.proto'

  if api_implementation.Type() != 'python' and build_file_if_cpp:
    # The C++ implementation requires all descriptors to be backed by the same
    # definition in the C++ descriptor pool. To do this, we build a
    # FileDescriptorProto with the same definition as this descriptor and build
    # it into the pool.
    file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
    file_descriptor_proto.message_type.add().MergeFrom(desc_proto)

    if package:
      file_descriptor_proto.package = package
    file_descriptor_proto.name = file_name

    _message.default_pool.Add(file_descriptor_proto)
    result = _message.default_pool.FindFileByName(file_descriptor_proto.name)

    if _USE_C_DESCRIPTORS:
      # With C descriptors the pool-built message descriptor is the result;
      # the pure-Python construction below is skipped entirely.
      return result.message_types_by_name[desc_proto.name]

  if file_desc is None:
    file_desc = FileDescriptor(
        pool=None,
        name=file_name,
        package=package,
        syntax=syntax,
        edition=edition,
        options=None,
        serialized_pb='',
        dependencies=[],
        public_dependencies=[],
        create_key=_internal_create_key,
    )
  full_message_name = [desc_proto.name]
  if package: full_message_name.insert(0, package)

  # Create Descriptors for enum types
  enum_types = {}
  for enum_proto in desc_proto.enum_type:
    full_name = '.'.join(full_message_name + [enum_proto.name])
    enum_desc = EnumDescriptor(
        enum_proto.name,
        full_name,
        None,
        [
            EnumValueDescriptor(
                enum_val.name,
                ii,
                enum_val.number,
                create_key=_internal_create_key,
            )
            for ii, enum_val in enumerate(enum_proto.value)
        ],
        file=file_desc,
        create_key=_internal_create_key,
    )
    enum_types[full_name] = enum_desc

  # Create Descriptors for nested types
  nested_types = {}
  for nested_proto in desc_proto.nested_type:
    full_name = '.'.join(full_message_name + [nested_proto.name])
    # Nested types are just those defined inside of the message, not all types
    # used by fields in the message, so no loops are possible here.
    nested_desc = MakeDescriptor(
        nested_proto,
        package='.'.join(full_message_name),
        build_file_if_cpp=False,
        syntax=syntax,
        edition=edition,
        file_desc=file_desc,
    )
    nested_types[full_name] = nested_desc

  fields = []
  for field_proto in desc_proto.field:
    full_name = '.'.join(full_message_name + [field_proto.name])
    enum_desc = None
    nested_desc = None
    if field_proto.json_name:
      json_name = field_proto.json_name
    else:
      json_name = None
    if field_proto.HasField('type_name'):
      # Only the last path component of type_name is used, resolved against
      # the local nested/enum tables built above.
      type_name = field_proto.type_name
      full_type_name = '.'.join(full_message_name +
                                [type_name[type_name.rfind('.')+1:]])
      if full_type_name in nested_types:
        nested_desc = nested_types[full_type_name]
      elif full_type_name in enum_types:
        enum_desc = enum_types[full_type_name]
      # Else type_name references a non-local type, which isn't implemented
    field = FieldDescriptor(
        field_proto.name,
        full_name,
        # NOTE(review): index is derived as number - 1, which presumably
        # assumes fields numbered consecutively from 1 — confirm upstream.
        field_proto.number - 1,
        field_proto.number,
        field_proto.type,
        FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type),
        field_proto.label,
        None,
        nested_desc,
        enum_desc,
        None,
        False,
        None,
        options=_OptionsOrNone(field_proto),
        has_default_value=False,
        json_name=json_name,
        file=file_desc,
        create_key=_internal_create_key,
    )
    fields.append(field)

  desc_name = '.'.join(full_message_name)
  return Descriptor(
      desc_proto.name,
      desc_name,
      None,
      None,
      fields,
      list(nested_types.values()),
      list(enum_types.values()),
      [],
      options=_OptionsOrNone(desc_proto),
      file=file_desc,
      create_key=_internal_create_key,
  )
|
.venv/lib/python3.11/site-packages/google/protobuf/descriptor_database.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Protocol Buffers - Google's data interchange format
|
| 2 |
+
# Copyright 2008 Google Inc. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Use of this source code is governed by a BSD-style
|
| 5 |
+
# license that can be found in the LICENSE file or at
|
| 6 |
+
# https://developers.google.com/open-source/licenses/bsd
|
| 7 |
+
|
| 8 |
+
"""Provides a container for DescriptorProtos."""
|
| 9 |
+
|
| 10 |
+
__author__ = 'matthewtoia@google.com (Matt Toia)'
|
| 11 |
+
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Error(Exception):
  """Base class for errors raised by this module."""
  pass
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DescriptorDatabaseConflictingDefinitionError(Error):
  """Raised when a proto is added with the same name & different descriptor.

  Raised by DescriptorDatabase.Add().
  """
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DescriptorDatabase(object):
  """A container accepting FileDescriptorProtos and maps DescriptorProtos."""

  def __init__(self):
    # File name -> FileDescriptorProto, and fully qualified symbol name ->
    # FileDescriptorProto that declares it.
    self._file_desc_protos_by_file = {}
    self._file_desc_protos_by_symbol = {}

  def Add(self, file_desc_proto):
    """Adds the FileDescriptorProto and its types to this database.

    Args:
      file_desc_proto: The FileDescriptorProto to add.
    Raises:
      DescriptorDatabaseConflictingDefinitionError: if an attempt is made to
        add a proto with the same name but different definition than an
        existing proto in the database.
    """
    proto_name = file_desc_proto.name
    if proto_name in self._file_desc_protos_by_file:
      if self._file_desc_protos_by_file[proto_name] != file_desc_proto:
        raise DescriptorDatabaseConflictingDefinitionError(
            '%s already added, but with different descriptor.' % proto_name)
      # Re-adding an identical definition is a no-op.
      return
    self._file_desc_protos_by_file[proto_name] = file_desc_proto

    # Add all the top-level descriptors to the index.
    package = file_desc_proto.package
    for message in file_desc_proto.message_type:
      for name in _ExtractSymbols(message, package):
        self._AddSymbol(name, file_desc_proto)
    for enum in file_desc_proto.enum_type:
      self._AddSymbol('.'.join((package, enum.name)), file_desc_proto)
      # Enum values are indexed directly, without conflict warnings.
      for enum_value in enum.value:
        self._file_desc_protos_by_symbol[
            '.'.join((package, enum_value.name))] = file_desc_proto
    for extension in file_desc_proto.extension:
      self._AddSymbol('.'.join((package, extension.name)), file_desc_proto)
    for service in file_desc_proto.service:
      self._AddSymbol('.'.join((package, service.name)), file_desc_proto)

  def FindFileByName(self, name):
    """Finds the file descriptor proto by file name.

    Typically the file name is a relative path ending to a .proto file. The
    proto with the given name will have to have been added to this database
    using the Add method or else an error will be raised.

    Args:
      name: The file name to find.

    Returns:
      The file descriptor proto matching the name.

    Raises:
      KeyError if no file by the given name was added.
    """
    return self._file_desc_protos_by_file[name]

  def FindFileContainingSymbol(self, symbol):
    """Finds the file descriptor proto containing the specified symbol.

    The symbol should be a fully qualified name including the file descriptor's
    package and any containing messages. Some examples:

    'some.package.name.Message'
    'some.package.name.Message.NestedEnum'
    'some.package.name.Message.some_field'

    The file descriptor proto containing the specified symbol must be added to
    this database using the Add method or else an error will be raised.

    Args:
      symbol: The fully qualified symbol name.

    Returns:
      The file descriptor proto containing the symbol.

    Raises:
      KeyError if no file contains the specified symbol.
    """
    try:
      return self._file_desc_protos_by_symbol[symbol]
    except KeyError:
      # Fields, enum values, and nested extensions are not in
      # _file_desc_protos_by_symbol. Fall back to the containing (parent)
      # symbol; a non-existent nested symbol under a valid top-level
      # descriptor is still found, matching protobuf C++.
      parent, _, _ = symbol.rpartition('.')
      try:
        return self._file_desc_protos_by_symbol[parent]
      except KeyError:
        # Report the symbol originally asked for, not its parent.
        raise KeyError(symbol)

  def FindFileContainingExtension(self, extendee_name, extension_number):
    # TODO: implement this API.
    return None

  def FindAllExtensionNumbers(self, extendee_name):
    # TODO: implement this API.
    return []

  def _AddSymbol(self, name, file_desc_proto):
    # Warn (but still overwrite) when two files claim the same symbol.
    if name in self._file_desc_protos_by_symbol:
      prior_file = self._file_desc_protos_by_symbol[name].name
      warnings.warn(
          'Conflict register for file "%s": %s is already defined in file "%s"'
          % (file_desc_proto.name, name, prior_file),
          RuntimeWarning)
    self._file_desc_protos_by_symbol[name] = file_desc_proto
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _ExtractSymbols(desc_proto, package):
|
| 139 |
+
"""Pulls out all the symbols from a descriptor proto.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
desc_proto: The proto to extract symbols from.
|
| 143 |
+
package: The package containing the descriptor type.
|
| 144 |
+
|
| 145 |
+
Yields:
|
| 146 |
+
The fully qualified name found in the descriptor.
|
| 147 |
+
"""
|
| 148 |
+
message_name = package + '.' + desc_proto.name if package else desc_proto.name
|
| 149 |
+
yield message_name
|
| 150 |
+
for nested_type in desc_proto.nested_type:
|
| 151 |
+
for symbol in _ExtractSymbols(nested_type, message_name):
|
| 152 |
+
yield symbol
|
| 153 |
+
for enum_type in desc_proto.enum_type:
|
| 154 |
+
yield '.'.join((message_name, enum_type.name))
|
.venv/lib/python3.11/site-packages/google/protobuf/descriptor_pb2.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/google/protobuf/descriptor_pool.py
ADDED
|
@@ -0,0 +1,1355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Protocol Buffers - Google's data interchange format
|
| 2 |
+
# Copyright 2008 Google Inc. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Use of this source code is governed by a BSD-style
|
| 5 |
+
# license that can be found in the LICENSE file or at
|
| 6 |
+
# https://developers.google.com/open-source/licenses/bsd
|
| 7 |
+
|
| 8 |
+
"""Provides DescriptorPool to use as a container for proto2 descriptors.
|
| 9 |
+
|
| 10 |
+
The DescriptorPool is used in conjection with a DescriptorDatabase to maintain
|
| 11 |
+
a collection of protocol buffer descriptors for use when dynamically creating
|
| 12 |
+
message types at runtime.
|
| 13 |
+
|
| 14 |
+
For most applications protocol buffers should be used via modules generated by
|
| 15 |
+
the protocol buffer compiler tool. This should only be used when the type of
|
| 16 |
+
protocol buffers used in an application or library cannot be predetermined.
|
| 17 |
+
|
| 18 |
+
Below is a straightforward example on how to use this class::
|
| 19 |
+
|
| 20 |
+
pool = DescriptorPool()
|
| 21 |
+
file_descriptor_protos = [ ... ]
|
| 22 |
+
for file_descriptor_proto in file_descriptor_protos:
|
| 23 |
+
pool.Add(file_descriptor_proto)
|
| 24 |
+
my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')
|
| 25 |
+
|
| 26 |
+
The message descriptor can be used in conjunction with the message_factory
|
| 27 |
+
module in order to create a protocol buffer class that can be encoded and
|
| 28 |
+
decoded.
|
| 29 |
+
|
| 30 |
+
If you want to get a Python class for the specified proto, use the
|
| 31 |
+
helper functions inside google.protobuf.message_factory
|
| 32 |
+
directly instead of this class.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
__author__ = 'matthewtoia@google.com (Matt Toia)'
|
| 36 |
+
|
| 37 |
+
import collections
|
| 38 |
+
import threading
|
| 39 |
+
import warnings
|
| 40 |
+
|
| 41 |
+
from google.protobuf import descriptor
|
| 42 |
+
from google.protobuf import descriptor_database
|
| 43 |
+
from google.protobuf import text_encoding
|
| 44 |
+
from google.protobuf.internal import python_edition_defaults
|
| 45 |
+
from google.protobuf.internal import python_message
|
| 46 |
+
|
| 47 |
+
_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _NormalizeFullyQualifiedName(name):
|
| 51 |
+
"""Remove leading period from fully-qualified type name.
|
| 52 |
+
|
| 53 |
+
Due to b/13860351 in descriptor_database.py, types in the root namespace are
|
| 54 |
+
generated with a leading period. This function removes that prefix.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
name (str): The fully-qualified symbol name.
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
str: The normalized fully-qualified symbol name.
|
| 61 |
+
"""
|
| 62 |
+
return name.lstrip('.')
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _OptionsOrNone(descriptor_proto):
|
| 66 |
+
"""Returns the value of the field `options`, or None if it is not set."""
|
| 67 |
+
if descriptor_proto.HasField('options'):
|
| 68 |
+
return descriptor_proto.options
|
| 69 |
+
else:
|
| 70 |
+
return None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _IsMessageSetExtension(field):
|
| 74 |
+
return (field.is_extension and
|
| 75 |
+
field.containing_type.has_options and
|
| 76 |
+
field.containing_type.GetOptions().message_set_wire_format and
|
| 77 |
+
field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
|
| 78 |
+
field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL)
|
| 79 |
+
|
| 80 |
+
_edition_defaults_lock = threading.Lock()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class DescriptorPool(object):
|
| 84 |
+
"""A collection of protobufs dynamically constructed by descriptor protos."""
|
| 85 |
+
|
| 86 |
+
if _USE_C_DESCRIPTORS:
|
| 87 |
+
|
| 88 |
+
def __new__(cls, descriptor_db=None):
|
| 89 |
+
# pylint: disable=protected-access
|
| 90 |
+
return descriptor._message.DescriptorPool(descriptor_db)
|
| 91 |
+
|
| 92 |
+
def __init__(
|
| 93 |
+
self, descriptor_db=None, use_deprecated_legacy_json_field_conflicts=False
|
| 94 |
+
):
|
| 95 |
+
"""Initializes a Pool of proto buffs.
|
| 96 |
+
|
| 97 |
+
The descriptor_db argument to the constructor is provided to allow
|
| 98 |
+
specialized file descriptor proto lookup code to be triggered on demand. An
|
| 99 |
+
example would be an implementation which will read and compile a file
|
| 100 |
+
specified in a call to FindFileByName() and not require the call to Add()
|
| 101 |
+
at all. Results from this database will be cached internally here as well.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
descriptor_db: A secondary source of file descriptors.
|
| 105 |
+
use_deprecated_legacy_json_field_conflicts: Unused, for compatibility with
|
| 106 |
+
C++.
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
self._internal_db = descriptor_database.DescriptorDatabase()
|
| 110 |
+
self._descriptor_db = descriptor_db
|
| 111 |
+
self._descriptors = {}
|
| 112 |
+
self._enum_descriptors = {}
|
| 113 |
+
self._service_descriptors = {}
|
| 114 |
+
self._file_descriptors = {}
|
| 115 |
+
self._toplevel_extensions = {}
|
| 116 |
+
self._top_enum_values = {}
|
| 117 |
+
# We store extensions in two two-level mappings: The first key is the
|
| 118 |
+
# descriptor of the message being extended, the second key is the extension
|
| 119 |
+
# full name or its tag number.
|
| 120 |
+
self._extensions_by_name = collections.defaultdict(dict)
|
| 121 |
+
self._extensions_by_number = collections.defaultdict(dict)
|
| 122 |
+
self._serialized_edition_defaults = (
|
| 123 |
+
python_edition_defaults._PROTOBUF_INTERNAL_PYTHON_EDITION_DEFAULTS
|
| 124 |
+
)
|
| 125 |
+
self._edition_defaults = None
|
| 126 |
+
self._feature_cache = dict()
|
| 127 |
+
|
| 128 |
+
def _CheckConflictRegister(self, desc, desc_name, file_name):
|
| 129 |
+
"""Check if the descriptor name conflicts with another of the same name.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
desc: Descriptor of a message, enum, service, extension or enum value.
|
| 133 |
+
desc_name (str): the full name of desc.
|
| 134 |
+
file_name (str): The file name of descriptor.
|
| 135 |
+
"""
|
| 136 |
+
for register, descriptor_type in [
|
| 137 |
+
(self._descriptors, descriptor.Descriptor),
|
| 138 |
+
(self._enum_descriptors, descriptor.EnumDescriptor),
|
| 139 |
+
(self._service_descriptors, descriptor.ServiceDescriptor),
|
| 140 |
+
(self._toplevel_extensions, descriptor.FieldDescriptor),
|
| 141 |
+
(self._top_enum_values, descriptor.EnumValueDescriptor)]:
|
| 142 |
+
if desc_name in register:
|
| 143 |
+
old_desc = register[desc_name]
|
| 144 |
+
if isinstance(old_desc, descriptor.EnumValueDescriptor):
|
| 145 |
+
old_file = old_desc.type.file.name
|
| 146 |
+
else:
|
| 147 |
+
old_file = old_desc.file.name
|
| 148 |
+
|
| 149 |
+
if not isinstance(desc, descriptor_type) or (
|
| 150 |
+
old_file != file_name):
|
| 151 |
+
error_msg = ('Conflict register for file "' + file_name +
|
| 152 |
+
'": ' + desc_name +
|
| 153 |
+
' is already defined in file "' +
|
| 154 |
+
old_file + '". Please fix the conflict by adding '
|
| 155 |
+
'package name on the proto file, or use different '
|
| 156 |
+
'name for the duplication.')
|
| 157 |
+
if isinstance(desc, descriptor.EnumValueDescriptor):
|
| 158 |
+
error_msg += ('\nNote: enum values appear as '
|
| 159 |
+
'siblings of the enum type instead of '
|
| 160 |
+
'children of it.')
|
| 161 |
+
|
| 162 |
+
raise TypeError(error_msg)
|
| 163 |
+
|
| 164 |
+
return
|
| 165 |
+
|
| 166 |
+
def Add(self, file_desc_proto):
|
| 167 |
+
"""Adds the FileDescriptorProto and its types to this pool.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
file_desc_proto (FileDescriptorProto): The file descriptor to add.
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
self._internal_db.Add(file_desc_proto)
|
| 174 |
+
|
| 175 |
+
def AddSerializedFile(self, serialized_file_desc_proto):
|
| 176 |
+
"""Adds the FileDescriptorProto and its types to this pool.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
serialized_file_desc_proto (bytes): A bytes string, serialization of the
|
| 180 |
+
:class:`FileDescriptorProto` to add.
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
FileDescriptor: Descriptor for the added file.
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
# pylint: disable=g-import-not-at-top
|
| 187 |
+
from google.protobuf import descriptor_pb2
|
| 188 |
+
file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
|
| 189 |
+
serialized_file_desc_proto)
|
| 190 |
+
file_desc = self._ConvertFileProtoToFileDescriptor(file_desc_proto)
|
| 191 |
+
file_desc.serialized_pb = serialized_file_desc_proto
|
| 192 |
+
return file_desc
|
| 193 |
+
|
| 194 |
+
# Never call this method. It is for internal usage only.
|
| 195 |
+
def _AddDescriptor(self, desc):
|
| 196 |
+
"""Adds a Descriptor to the pool, non-recursively.
|
| 197 |
+
|
| 198 |
+
If the Descriptor contains nested messages or enums, the caller must
|
| 199 |
+
explicitly register them. This method also registers the FileDescriptor
|
| 200 |
+
associated with the message.
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
desc: A Descriptor.
|
| 204 |
+
"""
|
| 205 |
+
if not isinstance(desc, descriptor.Descriptor):
|
| 206 |
+
raise TypeError('Expected instance of descriptor.Descriptor.')
|
| 207 |
+
|
| 208 |
+
self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
|
| 209 |
+
|
| 210 |
+
self._descriptors[desc.full_name] = desc
|
| 211 |
+
self._AddFileDescriptor(desc.file)
|
| 212 |
+
|
| 213 |
+
# Never call this method. It is for internal usage only.
|
| 214 |
+
def _AddEnumDescriptor(self, enum_desc):
|
| 215 |
+
"""Adds an EnumDescriptor to the pool.
|
| 216 |
+
|
| 217 |
+
This method also registers the FileDescriptor associated with the enum.
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
enum_desc: An EnumDescriptor.
|
| 221 |
+
"""
|
| 222 |
+
|
| 223 |
+
if not isinstance(enum_desc, descriptor.EnumDescriptor):
|
| 224 |
+
raise TypeError('Expected instance of descriptor.EnumDescriptor.')
|
| 225 |
+
|
| 226 |
+
file_name = enum_desc.file.name
|
| 227 |
+
self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
|
| 228 |
+
self._enum_descriptors[enum_desc.full_name] = enum_desc
|
| 229 |
+
|
| 230 |
+
# Top enum values need to be indexed.
|
| 231 |
+
# Count the number of dots to see whether the enum is toplevel or nested
|
| 232 |
+
# in a message. We cannot use enum_desc.containing_type at this stage.
|
| 233 |
+
if enum_desc.file.package:
|
| 234 |
+
top_level = (enum_desc.full_name.count('.')
|
| 235 |
+
- enum_desc.file.package.count('.') == 1)
|
| 236 |
+
else:
|
| 237 |
+
top_level = enum_desc.full_name.count('.') == 0
|
| 238 |
+
if top_level:
|
| 239 |
+
file_name = enum_desc.file.name
|
| 240 |
+
package = enum_desc.file.package
|
| 241 |
+
for enum_value in enum_desc.values:
|
| 242 |
+
full_name = _NormalizeFullyQualifiedName(
|
| 243 |
+
'.'.join((package, enum_value.name)))
|
| 244 |
+
self._CheckConflictRegister(enum_value, full_name, file_name)
|
| 245 |
+
self._top_enum_values[full_name] = enum_value
|
| 246 |
+
self._AddFileDescriptor(enum_desc.file)
|
| 247 |
+
|
| 248 |
+
# Never call this method. It is for internal usage only.
|
| 249 |
+
def _AddServiceDescriptor(self, service_desc):
|
| 250 |
+
"""Adds a ServiceDescriptor to the pool.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
service_desc: A ServiceDescriptor.
|
| 254 |
+
"""
|
| 255 |
+
|
| 256 |
+
if not isinstance(service_desc, descriptor.ServiceDescriptor):
|
| 257 |
+
raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
|
| 258 |
+
|
| 259 |
+
self._CheckConflictRegister(service_desc, service_desc.full_name,
|
| 260 |
+
service_desc.file.name)
|
| 261 |
+
self._service_descriptors[service_desc.full_name] = service_desc
|
| 262 |
+
|
| 263 |
+
# Never call this method. It is for internal usage only.
|
| 264 |
+
def _AddExtensionDescriptor(self, extension):
|
| 265 |
+
"""Adds a FieldDescriptor describing an extension to the pool.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
extension: A FieldDescriptor.
|
| 269 |
+
|
| 270 |
+
Raises:
|
| 271 |
+
AssertionError: when another extension with the same number extends the
|
| 272 |
+
same message.
|
| 273 |
+
TypeError: when the specified extension is not a
|
| 274 |
+
descriptor.FieldDescriptor.
|
| 275 |
+
"""
|
| 276 |
+
if not (isinstance(extension, descriptor.FieldDescriptor) and
|
| 277 |
+
extension.is_extension):
|
| 278 |
+
raise TypeError('Expected an extension descriptor.')
|
| 279 |
+
|
| 280 |
+
if extension.extension_scope is None:
|
| 281 |
+
self._CheckConflictRegister(
|
| 282 |
+
extension, extension.full_name, extension.file.name)
|
| 283 |
+
self._toplevel_extensions[extension.full_name] = extension
|
| 284 |
+
|
| 285 |
+
try:
|
| 286 |
+
existing_desc = self._extensions_by_number[
|
| 287 |
+
extension.containing_type][extension.number]
|
| 288 |
+
except KeyError:
|
| 289 |
+
pass
|
| 290 |
+
else:
|
| 291 |
+
if extension is not existing_desc:
|
| 292 |
+
raise AssertionError(
|
| 293 |
+
'Extensions "%s" and "%s" both try to extend message type "%s" '
|
| 294 |
+
'with field number %d.' %
|
| 295 |
+
(extension.full_name, existing_desc.full_name,
|
| 296 |
+
extension.containing_type.full_name, extension.number))
|
| 297 |
+
|
| 298 |
+
self._extensions_by_number[extension.containing_type][
|
| 299 |
+
extension.number] = extension
|
| 300 |
+
self._extensions_by_name[extension.containing_type][
|
| 301 |
+
extension.full_name] = extension
|
| 302 |
+
|
| 303 |
+
# Also register MessageSet extensions with the type name.
|
| 304 |
+
if _IsMessageSetExtension(extension):
|
| 305 |
+
self._extensions_by_name[extension.containing_type][
|
| 306 |
+
extension.message_type.full_name] = extension
|
| 307 |
+
|
| 308 |
+
if hasattr(extension.containing_type, '_concrete_class'):
|
| 309 |
+
python_message._AttachFieldHelpers(
|
| 310 |
+
extension.containing_type._concrete_class, extension)
|
| 311 |
+
|
| 312 |
+
# Never call this method. It is for internal usage only.
|
| 313 |
+
def _InternalAddFileDescriptor(self, file_desc):
|
| 314 |
+
"""Adds a FileDescriptor to the pool, non-recursively.
|
| 315 |
+
|
| 316 |
+
If the FileDescriptor contains messages or enums, the caller must explicitly
|
| 317 |
+
register them.
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
file_desc: A FileDescriptor.
|
| 321 |
+
"""
|
| 322 |
+
|
| 323 |
+
self._AddFileDescriptor(file_desc)
|
| 324 |
+
|
| 325 |
+
def _AddFileDescriptor(self, file_desc):
|
| 326 |
+
"""Adds a FileDescriptor to the pool, non-recursively.
|
| 327 |
+
|
| 328 |
+
If the FileDescriptor contains messages or enums, the caller must explicitly
|
| 329 |
+
register them.
|
| 330 |
+
|
| 331 |
+
Args:
|
| 332 |
+
file_desc: A FileDescriptor.
|
| 333 |
+
"""
|
| 334 |
+
|
| 335 |
+
if not isinstance(file_desc, descriptor.FileDescriptor):
|
| 336 |
+
raise TypeError('Expected instance of descriptor.FileDescriptor.')
|
| 337 |
+
self._file_descriptors[file_desc.name] = file_desc
|
| 338 |
+
|
| 339 |
+
def FindFileByName(self, file_name):
|
| 340 |
+
"""Gets a FileDescriptor by file name.
|
| 341 |
+
|
| 342 |
+
Args:
|
| 343 |
+
file_name (str): The path to the file to get a descriptor for.
|
| 344 |
+
|
| 345 |
+
Returns:
|
| 346 |
+
FileDescriptor: The descriptor for the named file.
|
| 347 |
+
|
| 348 |
+
Raises:
|
| 349 |
+
KeyError: if the file cannot be found in the pool.
|
| 350 |
+
"""
|
| 351 |
+
|
| 352 |
+
try:
|
| 353 |
+
return self._file_descriptors[file_name]
|
| 354 |
+
except KeyError:
|
| 355 |
+
pass
|
| 356 |
+
|
| 357 |
+
try:
|
| 358 |
+
file_proto = self._internal_db.FindFileByName(file_name)
|
| 359 |
+
except KeyError as error:
|
| 360 |
+
if self._descriptor_db:
|
| 361 |
+
file_proto = self._descriptor_db.FindFileByName(file_name)
|
| 362 |
+
else:
|
| 363 |
+
raise error
|
| 364 |
+
if not file_proto:
|
| 365 |
+
raise KeyError('Cannot find a file named %s' % file_name)
|
| 366 |
+
return self._ConvertFileProtoToFileDescriptor(file_proto)
|
| 367 |
+
|
| 368 |
+
def FindFileContainingSymbol(self, symbol):
|
| 369 |
+
"""Gets the FileDescriptor for the file containing the specified symbol.
|
| 370 |
+
|
| 371 |
+
Args:
|
| 372 |
+
symbol (str): The name of the symbol to search for.
|
| 373 |
+
|
| 374 |
+
Returns:
|
| 375 |
+
FileDescriptor: Descriptor for the file that contains the specified
|
| 376 |
+
symbol.
|
| 377 |
+
|
| 378 |
+
Raises:
|
| 379 |
+
KeyError: if the file cannot be found in the pool.
|
| 380 |
+
"""
|
| 381 |
+
|
| 382 |
+
symbol = _NormalizeFullyQualifiedName(symbol)
|
| 383 |
+
try:
|
| 384 |
+
return self._InternalFindFileContainingSymbol(symbol)
|
| 385 |
+
except KeyError:
|
| 386 |
+
pass
|
| 387 |
+
|
| 388 |
+
try:
|
| 389 |
+
# Try fallback database. Build and find again if possible.
|
| 390 |
+
self._FindFileContainingSymbolInDb(symbol)
|
| 391 |
+
return self._InternalFindFileContainingSymbol(symbol)
|
| 392 |
+
except KeyError:
|
| 393 |
+
raise KeyError('Cannot find a file containing %s' % symbol)
|
| 394 |
+
|
| 395 |
+
def _InternalFindFileContainingSymbol(self, symbol):
|
| 396 |
+
"""Gets the already built FileDescriptor containing the specified symbol.
|
| 397 |
+
|
| 398 |
+
Args:
|
| 399 |
+
symbol (str): The name of the symbol to search for.
|
| 400 |
+
|
| 401 |
+
Returns:
|
| 402 |
+
FileDescriptor: Descriptor for the file that contains the specified
|
| 403 |
+
symbol.
|
| 404 |
+
|
| 405 |
+
Raises:
|
| 406 |
+
KeyError: if the file cannot be found in the pool.
|
| 407 |
+
"""
|
| 408 |
+
try:
|
| 409 |
+
return self._descriptors[symbol].file
|
| 410 |
+
except KeyError:
|
| 411 |
+
pass
|
| 412 |
+
|
| 413 |
+
try:
|
| 414 |
+
return self._enum_descriptors[symbol].file
|
| 415 |
+
except KeyError:
|
| 416 |
+
pass
|
| 417 |
+
|
| 418 |
+
try:
|
| 419 |
+
return self._service_descriptors[symbol].file
|
| 420 |
+
except KeyError:
|
| 421 |
+
pass
|
| 422 |
+
|
| 423 |
+
try:
|
| 424 |
+
return self._top_enum_values[symbol].type.file
|
| 425 |
+
except KeyError:
|
| 426 |
+
pass
|
| 427 |
+
|
| 428 |
+
try:
|
| 429 |
+
return self._toplevel_extensions[symbol].file
|
| 430 |
+
except KeyError:
|
| 431 |
+
pass
|
| 432 |
+
|
| 433 |
+
# Try fields, enum values and nested extensions inside a message.
|
| 434 |
+
top_name, _, sub_name = symbol.rpartition('.')
|
| 435 |
+
try:
|
| 436 |
+
message = self.FindMessageTypeByName(top_name)
|
| 437 |
+
assert (sub_name in message.extensions_by_name or
|
| 438 |
+
sub_name in message.fields_by_name or
|
| 439 |
+
sub_name in message.enum_values_by_name)
|
| 440 |
+
return message.file
|
| 441 |
+
except (KeyError, AssertionError):
|
| 442 |
+
raise KeyError('Cannot find a file containing %s' % symbol)
|
| 443 |
+
|
| 444 |
+
def FindMessageTypeByName(self, full_name):
|
| 445 |
+
"""Loads the named descriptor from the pool.
|
| 446 |
+
|
| 447 |
+
Args:
|
| 448 |
+
full_name (str): The full name of the descriptor to load.
|
| 449 |
+
|
| 450 |
+
Returns:
|
| 451 |
+
Descriptor: The descriptor for the named type.
|
| 452 |
+
|
| 453 |
+
Raises:
|
| 454 |
+
KeyError: if the message cannot be found in the pool.
|
| 455 |
+
"""
|
| 456 |
+
|
| 457 |
+
full_name = _NormalizeFullyQualifiedName(full_name)
|
| 458 |
+
if full_name not in self._descriptors:
|
| 459 |
+
self._FindFileContainingSymbolInDb(full_name)
|
| 460 |
+
return self._descriptors[full_name]
|
| 461 |
+
|
| 462 |
+
def FindEnumTypeByName(self, full_name):
|
| 463 |
+
"""Loads the named enum descriptor from the pool.
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
full_name (str): The full name of the enum descriptor to load.
|
| 467 |
+
|
| 468 |
+
Returns:
|
| 469 |
+
EnumDescriptor: The enum descriptor for the named type.
|
| 470 |
+
|
| 471 |
+
Raises:
|
| 472 |
+
KeyError: if the enum cannot be found in the pool.
|
| 473 |
+
"""
|
| 474 |
+
|
| 475 |
+
full_name = _NormalizeFullyQualifiedName(full_name)
|
| 476 |
+
if full_name not in self._enum_descriptors:
|
| 477 |
+
self._FindFileContainingSymbolInDb(full_name)
|
| 478 |
+
return self._enum_descriptors[full_name]
|
| 479 |
+
|
| 480 |
+
def FindFieldByName(self, full_name):
|
| 481 |
+
"""Loads the named field descriptor from the pool.
|
| 482 |
+
|
| 483 |
+
Args:
|
| 484 |
+
full_name (str): The full name of the field descriptor to load.
|
| 485 |
+
|
| 486 |
+
Returns:
|
| 487 |
+
FieldDescriptor: The field descriptor for the named field.
|
| 488 |
+
|
| 489 |
+
Raises:
|
| 490 |
+
KeyError: if the field cannot be found in the pool.
|
| 491 |
+
"""
|
| 492 |
+
full_name = _NormalizeFullyQualifiedName(full_name)
|
| 493 |
+
message_name, _, field_name = full_name.rpartition('.')
|
| 494 |
+
message_descriptor = self.FindMessageTypeByName(message_name)
|
| 495 |
+
return message_descriptor.fields_by_name[field_name]
|
| 496 |
+
|
| 497 |
+
def FindOneofByName(self, full_name):
|
| 498 |
+
"""Loads the named oneof descriptor from the pool.
|
| 499 |
+
|
| 500 |
+
Args:
|
| 501 |
+
full_name (str): The full name of the oneof descriptor to load.
|
| 502 |
+
|
| 503 |
+
Returns:
|
| 504 |
+
OneofDescriptor: The oneof descriptor for the named oneof.
|
| 505 |
+
|
| 506 |
+
Raises:
|
| 507 |
+
KeyError: if the oneof cannot be found in the pool.
|
| 508 |
+
"""
|
| 509 |
+
full_name = _NormalizeFullyQualifiedName(full_name)
|
| 510 |
+
message_name, _, oneof_name = full_name.rpartition('.')
|
| 511 |
+
message_descriptor = self.FindMessageTypeByName(message_name)
|
| 512 |
+
return message_descriptor.oneofs_by_name[oneof_name]
|
| 513 |
+
|
| 514 |
+
def FindExtensionByName(self, full_name):
|
| 515 |
+
"""Loads the named extension descriptor from the pool.
|
| 516 |
+
|
| 517 |
+
Args:
|
| 518 |
+
full_name (str): The full name of the extension descriptor to load.
|
| 519 |
+
|
| 520 |
+
Returns:
|
| 521 |
+
FieldDescriptor: The field descriptor for the named extension.
|
| 522 |
+
|
| 523 |
+
Raises:
|
| 524 |
+
KeyError: if the extension cannot be found in the pool.
|
| 525 |
+
"""
|
| 526 |
+
full_name = _NormalizeFullyQualifiedName(full_name)
|
| 527 |
+
try:
|
| 528 |
+
# The proto compiler does not give any link between the FileDescriptor
|
| 529 |
+
# and top-level extensions unless the FileDescriptorProto is added to
|
| 530 |
+
# the DescriptorDatabase, but this can impact memory usage.
|
| 531 |
+
# So we registered these extensions by name explicitly.
|
| 532 |
+
return self._toplevel_extensions[full_name]
|
| 533 |
+
except KeyError:
|
| 534 |
+
pass
|
| 535 |
+
message_name, _, extension_name = full_name.rpartition('.')
|
| 536 |
+
try:
|
| 537 |
+
# Most extensions are nested inside a message.
|
| 538 |
+
scope = self.FindMessageTypeByName(message_name)
|
| 539 |
+
except KeyError:
|
| 540 |
+
# Some extensions are defined at file scope.
|
| 541 |
+
scope = self._FindFileContainingSymbolInDb(full_name)
|
| 542 |
+
return scope.extensions_by_name[extension_name]
|
| 543 |
+
|
| 544 |
+
def FindExtensionByNumber(self, message_descriptor, number):
|
| 545 |
+
"""Gets the extension of the specified message with the specified number.
|
| 546 |
+
|
| 547 |
+
Extensions have to be registered to this pool by calling :func:`Add` or
|
| 548 |
+
:func:`AddExtensionDescriptor`.
|
| 549 |
+
|
| 550 |
+
Args:
|
| 551 |
+
message_descriptor (Descriptor): descriptor of the extended message.
|
| 552 |
+
number (int): Number of the extension field.
|
| 553 |
+
|
| 554 |
+
Returns:
|
| 555 |
+
FieldDescriptor: The descriptor for the extension.
|
| 556 |
+
|
| 557 |
+
Raises:
|
| 558 |
+
KeyError: when no extension with the given number is known for the
|
| 559 |
+
specified message.
|
| 560 |
+
"""
|
| 561 |
+
try:
|
| 562 |
+
return self._extensions_by_number[message_descriptor][number]
|
| 563 |
+
except KeyError:
|
| 564 |
+
self._TryLoadExtensionFromDB(message_descriptor, number)
|
| 565 |
+
return self._extensions_by_number[message_descriptor][number]
|
| 566 |
+
|
| 567 |
+
def FindAllExtensions(self, message_descriptor):
|
| 568 |
+
"""Gets all the known extensions of a given message.
|
| 569 |
+
|
| 570 |
+
Extensions have to be registered to this pool by build related
|
| 571 |
+
:func:`Add` or :func:`AddExtensionDescriptor`.
|
| 572 |
+
|
| 573 |
+
Args:
|
| 574 |
+
message_descriptor (Descriptor): Descriptor of the extended message.
|
| 575 |
+
|
| 576 |
+
Returns:
|
| 577 |
+
list[FieldDescriptor]: Field descriptors describing the extensions.
|
| 578 |
+
"""
|
| 579 |
+
# Fallback to descriptor db if FindAllExtensionNumbers is provided.
|
| 580 |
+
if self._descriptor_db and hasattr(
|
| 581 |
+
self._descriptor_db, 'FindAllExtensionNumbers'):
|
| 582 |
+
full_name = message_descriptor.full_name
|
| 583 |
+
all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
|
| 584 |
+
for number in all_numbers:
|
| 585 |
+
if number in self._extensions_by_number[message_descriptor]:
|
| 586 |
+
continue
|
| 587 |
+
self._TryLoadExtensionFromDB(message_descriptor, number)
|
| 588 |
+
|
| 589 |
+
return list(self._extensions_by_number[message_descriptor].values())
|
| 590 |
+
|
| 591 |
+
def _TryLoadExtensionFromDB(self, message_descriptor, number):
|
| 592 |
+
"""Try to Load extensions from descriptor db.
|
| 593 |
+
|
| 594 |
+
Args:
|
| 595 |
+
message_descriptor: descriptor of the extended message.
|
| 596 |
+
number: the extension number that needs to be loaded.
|
| 597 |
+
"""
|
| 598 |
+
if not self._descriptor_db:
|
| 599 |
+
return
|
| 600 |
+
# Only supported when FindFileContainingExtension is provided.
|
| 601 |
+
if not hasattr(
|
| 602 |
+
self._descriptor_db, 'FindFileContainingExtension'):
|
| 603 |
+
return
|
| 604 |
+
|
| 605 |
+
full_name = message_descriptor.full_name
|
| 606 |
+
file_proto = self._descriptor_db.FindFileContainingExtension(
|
| 607 |
+
full_name, number)
|
| 608 |
+
|
| 609 |
+
if file_proto is None:
|
| 610 |
+
return
|
| 611 |
+
|
| 612 |
+
try:
|
| 613 |
+
self._ConvertFileProtoToFileDescriptor(file_proto)
|
| 614 |
+
except:
|
| 615 |
+
warn_msg = ('Unable to load proto file %s for extension number %d.' %
|
| 616 |
+
(file_proto.name, number))
|
| 617 |
+
warnings.warn(warn_msg, RuntimeWarning)
|
| 618 |
+
|
| 619 |
+
def FindServiceByName(self, full_name):
|
| 620 |
+
"""Loads the named service descriptor from the pool.
|
| 621 |
+
|
| 622 |
+
Args:
|
| 623 |
+
full_name (str): The full name of the service descriptor to load.
|
| 624 |
+
|
| 625 |
+
Returns:
|
| 626 |
+
ServiceDescriptor: The service descriptor for the named service.
|
| 627 |
+
|
| 628 |
+
Raises:
|
| 629 |
+
KeyError: if the service cannot be found in the pool.
|
| 630 |
+
"""
|
| 631 |
+
full_name = _NormalizeFullyQualifiedName(full_name)
|
| 632 |
+
if full_name not in self._service_descriptors:
|
| 633 |
+
self._FindFileContainingSymbolInDb(full_name)
|
| 634 |
+
return self._service_descriptors[full_name]
|
| 635 |
+
|
| 636 |
+
def FindMethodByName(self, full_name):
|
| 637 |
+
"""Loads the named service method descriptor from the pool.
|
| 638 |
+
|
| 639 |
+
Args:
|
| 640 |
+
full_name (str): The full name of the method descriptor to load.
|
| 641 |
+
|
| 642 |
+
Returns:
|
| 643 |
+
MethodDescriptor: The method descriptor for the service method.
|
| 644 |
+
|
| 645 |
+
Raises:
|
| 646 |
+
KeyError: if the method cannot be found in the pool.
|
| 647 |
+
"""
|
| 648 |
+
full_name = _NormalizeFullyQualifiedName(full_name)
|
| 649 |
+
service_name, _, method_name = full_name.rpartition('.')
|
| 650 |
+
service_descriptor = self.FindServiceByName(service_name)
|
| 651 |
+
return service_descriptor.methods_by_name[method_name]
|
| 652 |
+
|
| 653 |
+
def SetFeatureSetDefaults(self, defaults):
|
| 654 |
+
"""Sets the default feature mappings used during the build.
|
| 655 |
+
|
| 656 |
+
Args:
|
| 657 |
+
defaults: a FeatureSetDefaults message containing the new mappings.
|
| 658 |
+
"""
|
| 659 |
+
if self._edition_defaults is not None:
|
| 660 |
+
raise ValueError(
|
| 661 |
+
"Feature set defaults can't be changed once the pool has started"
|
| 662 |
+
' building!'
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
# pylint: disable=g-import-not-at-top
|
| 666 |
+
from google.protobuf import descriptor_pb2
|
| 667 |
+
|
| 668 |
+
if not isinstance(defaults, descriptor_pb2.FeatureSetDefaults):
|
| 669 |
+
raise TypeError('SetFeatureSetDefaults called with invalid type')
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
if defaults.minimum_edition > defaults.maximum_edition:
|
| 673 |
+
raise ValueError(
|
| 674 |
+
'Invalid edition range %s to %s'
|
| 675 |
+
% (
|
| 676 |
+
descriptor_pb2.Edition.Name(defaults.minimum_edition),
|
| 677 |
+
descriptor_pb2.Edition.Name(defaults.maximum_edition),
|
| 678 |
+
)
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
prev_edition = descriptor_pb2.Edition.EDITION_UNKNOWN
|
| 682 |
+
for d in defaults.defaults:
|
| 683 |
+
if d.edition == descriptor_pb2.Edition.EDITION_UNKNOWN:
|
| 684 |
+
raise ValueError('Invalid edition EDITION_UNKNOWN specified')
|
| 685 |
+
if prev_edition >= d.edition:
|
| 686 |
+
raise ValueError(
|
| 687 |
+
'Feature set defaults are not strictly increasing. %s is greater'
|
| 688 |
+
' than or equal to %s'
|
| 689 |
+
% (
|
| 690 |
+
descriptor_pb2.Edition.Name(prev_edition),
|
| 691 |
+
descriptor_pb2.Edition.Name(d.edition),
|
| 692 |
+
)
|
| 693 |
+
)
|
| 694 |
+
prev_edition = d.edition
|
| 695 |
+
self._edition_defaults = defaults
|
| 696 |
+
|
| 697 |
+
def _CreateDefaultFeatures(self, edition):
|
| 698 |
+
"""Creates a FeatureSet message with defaults for a specific edition.
|
| 699 |
+
|
| 700 |
+
Args:
|
| 701 |
+
edition: the edition to generate defaults for.
|
| 702 |
+
|
| 703 |
+
Returns:
|
| 704 |
+
A FeatureSet message with defaults for a specific edition.
|
| 705 |
+
"""
|
| 706 |
+
# pylint: disable=g-import-not-at-top
|
| 707 |
+
from google.protobuf import descriptor_pb2
|
| 708 |
+
|
| 709 |
+
with _edition_defaults_lock:
|
| 710 |
+
if not self._edition_defaults:
|
| 711 |
+
self._edition_defaults = descriptor_pb2.FeatureSetDefaults()
|
| 712 |
+
self._edition_defaults.ParseFromString(
|
| 713 |
+
self._serialized_edition_defaults
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
if edition < self._edition_defaults.minimum_edition:
|
| 717 |
+
raise TypeError(
|
| 718 |
+
'Edition %s is earlier than the minimum supported edition %s!'
|
| 719 |
+
% (
|
| 720 |
+
descriptor_pb2.Edition.Name(edition),
|
| 721 |
+
descriptor_pb2.Edition.Name(
|
| 722 |
+
self._edition_defaults.minimum_edition
|
| 723 |
+
),
|
| 724 |
+
)
|
| 725 |
+
)
|
| 726 |
+
if edition > self._edition_defaults.maximum_edition:
|
| 727 |
+
raise TypeError(
|
| 728 |
+
'Edition %s is later than the maximum supported edition %s!'
|
| 729 |
+
% (
|
| 730 |
+
descriptor_pb2.Edition.Name(edition),
|
| 731 |
+
descriptor_pb2.Edition.Name(
|
| 732 |
+
self._edition_defaults.maximum_edition
|
| 733 |
+
),
|
| 734 |
+
)
|
| 735 |
+
)
|
| 736 |
+
found = None
|
| 737 |
+
for d in self._edition_defaults.defaults:
|
| 738 |
+
if d.edition > edition:
|
| 739 |
+
break
|
| 740 |
+
found = d
|
| 741 |
+
if found is None:
|
| 742 |
+
raise TypeError(
|
| 743 |
+
'No valid default found for edition %s!'
|
| 744 |
+
% descriptor_pb2.Edition.Name(edition)
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
defaults = descriptor_pb2.FeatureSet()
|
| 748 |
+
defaults.CopyFrom(found.fixed_features)
|
| 749 |
+
defaults.MergeFrom(found.overridable_features)
|
| 750 |
+
return defaults
|
| 751 |
+
|
| 752 |
+
def _InternFeatures(self, features):
|
| 753 |
+
serialized = features.SerializeToString()
|
| 754 |
+
with _edition_defaults_lock:
|
| 755 |
+
cached = self._feature_cache.get(serialized)
|
| 756 |
+
if cached is None:
|
| 757 |
+
self._feature_cache[serialized] = features
|
| 758 |
+
cached = features
|
| 759 |
+
return cached
|
| 760 |
+
|
| 761 |
+
def _FindFileContainingSymbolInDb(self, symbol):
|
| 762 |
+
"""Finds the file in descriptor DB containing the specified symbol.
|
| 763 |
+
|
| 764 |
+
Args:
|
| 765 |
+
symbol (str): The name of the symbol to search for.
|
| 766 |
+
|
| 767 |
+
Returns:
|
| 768 |
+
FileDescriptor: The file that contains the specified symbol.
|
| 769 |
+
|
| 770 |
+
Raises:
|
| 771 |
+
KeyError: if the file cannot be found in the descriptor database.
|
| 772 |
+
"""
|
| 773 |
+
try:
|
| 774 |
+
file_proto = self._internal_db.FindFileContainingSymbol(symbol)
|
| 775 |
+
except KeyError as error:
|
| 776 |
+
if self._descriptor_db:
|
| 777 |
+
file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
|
| 778 |
+
else:
|
| 779 |
+
raise error
|
| 780 |
+
if not file_proto:
|
| 781 |
+
raise KeyError('Cannot find a file containing %s' % symbol)
|
| 782 |
+
return self._ConvertFileProtoToFileDescriptor(file_proto)
|
| 783 |
+
|
| 784 |
+
def _ConvertFileProtoToFileDescriptor(self, file_proto):
|
| 785 |
+
"""Creates a FileDescriptor from a proto or returns a cached copy.
|
| 786 |
+
|
| 787 |
+
This method also has the side effect of loading all the symbols found in
|
| 788 |
+
the file into the appropriate dictionaries in the pool.
|
| 789 |
+
|
| 790 |
+
Args:
|
| 791 |
+
file_proto: The proto to convert.
|
| 792 |
+
|
| 793 |
+
Returns:
|
| 794 |
+
A FileDescriptor matching the passed in proto.
|
| 795 |
+
"""
|
| 796 |
+
if file_proto.name not in self._file_descriptors:
|
| 797 |
+
built_deps = list(self._GetDeps(file_proto.dependency))
|
| 798 |
+
direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
|
| 799 |
+
public_deps = [direct_deps[i] for i in file_proto.public_dependency]
|
| 800 |
+
|
| 801 |
+
# pylint: disable=g-import-not-at-top
|
| 802 |
+
from google.protobuf import descriptor_pb2
|
| 803 |
+
|
| 804 |
+
file_descriptor = descriptor.FileDescriptor(
|
| 805 |
+
pool=self,
|
| 806 |
+
name=file_proto.name,
|
| 807 |
+
package=file_proto.package,
|
| 808 |
+
syntax=file_proto.syntax,
|
| 809 |
+
edition=descriptor_pb2.Edition.Name(file_proto.edition),
|
| 810 |
+
options=_OptionsOrNone(file_proto),
|
| 811 |
+
serialized_pb=file_proto.SerializeToString(),
|
| 812 |
+
dependencies=direct_deps,
|
| 813 |
+
public_dependencies=public_deps,
|
| 814 |
+
# pylint: disable=protected-access
|
| 815 |
+
create_key=descriptor._internal_create_key,
|
| 816 |
+
)
|
| 817 |
+
scope = {}
|
| 818 |
+
|
| 819 |
+
# This loop extracts all the message and enum types from all the
|
| 820 |
+
# dependencies of the file_proto. This is necessary to create the
|
| 821 |
+
# scope of available message types when defining the passed in
|
| 822 |
+
# file proto.
|
| 823 |
+
for dependency in built_deps:
|
| 824 |
+
scope.update(self._ExtractSymbols(
|
| 825 |
+
dependency.message_types_by_name.values()))
|
| 826 |
+
scope.update((_PrefixWithDot(enum.full_name), enum)
|
| 827 |
+
for enum in dependency.enum_types_by_name.values())
|
| 828 |
+
|
| 829 |
+
for message_type in file_proto.message_type:
|
| 830 |
+
message_desc = self._ConvertMessageDescriptor(
|
| 831 |
+
message_type, file_proto.package, file_descriptor, scope,
|
| 832 |
+
file_proto.syntax)
|
| 833 |
+
file_descriptor.message_types_by_name[message_desc.name] = (
|
| 834 |
+
message_desc)
|
| 835 |
+
|
| 836 |
+
for enum_type in file_proto.enum_type:
|
| 837 |
+
file_descriptor.enum_types_by_name[enum_type.name] = (
|
| 838 |
+
self._ConvertEnumDescriptor(enum_type, file_proto.package,
|
| 839 |
+
file_descriptor, None, scope, True))
|
| 840 |
+
|
| 841 |
+
for index, extension_proto in enumerate(file_proto.extension):
|
| 842 |
+
extension_desc = self._MakeFieldDescriptor(
|
| 843 |
+
extension_proto, file_proto.package, index, file_descriptor,
|
| 844 |
+
is_extension=True)
|
| 845 |
+
extension_desc.containing_type = self._GetTypeFromScope(
|
| 846 |
+
file_descriptor.package, extension_proto.extendee, scope)
|
| 847 |
+
self._SetFieldType(extension_proto, extension_desc,
|
| 848 |
+
file_descriptor.package, scope)
|
| 849 |
+
file_descriptor.extensions_by_name[extension_desc.name] = (
|
| 850 |
+
extension_desc)
|
| 851 |
+
|
| 852 |
+
for desc_proto in file_proto.message_type:
|
| 853 |
+
self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
|
| 854 |
+
|
| 855 |
+
if file_proto.package:
|
| 856 |
+
desc_proto_prefix = _PrefixWithDot(file_proto.package)
|
| 857 |
+
else:
|
| 858 |
+
desc_proto_prefix = ''
|
| 859 |
+
|
| 860 |
+
for desc_proto in file_proto.message_type:
|
| 861 |
+
desc = self._GetTypeFromScope(
|
| 862 |
+
desc_proto_prefix, desc_proto.name, scope)
|
| 863 |
+
file_descriptor.message_types_by_name[desc_proto.name] = desc
|
| 864 |
+
|
| 865 |
+
for index, service_proto in enumerate(file_proto.service):
|
| 866 |
+
file_descriptor.services_by_name[service_proto.name] = (
|
| 867 |
+
self._MakeServiceDescriptor(service_proto, index, scope,
|
| 868 |
+
file_proto.package, file_descriptor))
|
| 869 |
+
|
| 870 |
+
self._file_descriptors[file_proto.name] = file_descriptor
|
| 871 |
+
|
| 872 |
+
# Add extensions to the pool
|
| 873 |
+
def AddExtensionForNested(message_type):
|
| 874 |
+
for nested in message_type.nested_types:
|
| 875 |
+
AddExtensionForNested(nested)
|
| 876 |
+
for extension in message_type.extensions:
|
| 877 |
+
self._AddExtensionDescriptor(extension)
|
| 878 |
+
|
| 879 |
+
file_desc = self._file_descriptors[file_proto.name]
|
| 880 |
+
for extension in file_desc.extensions_by_name.values():
|
| 881 |
+
self._AddExtensionDescriptor(extension)
|
| 882 |
+
for message_type in file_desc.message_types_by_name.values():
|
| 883 |
+
AddExtensionForNested(message_type)
|
| 884 |
+
|
| 885 |
+
return file_desc
|
| 886 |
+
|
| 887 |
+
def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
|
| 888 |
+
scope=None, syntax=None):
|
| 889 |
+
"""Adds the proto to the pool in the specified package.
|
| 890 |
+
|
| 891 |
+
Args:
|
| 892 |
+
desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
|
| 893 |
+
package: The package the proto should be located in.
|
| 894 |
+
file_desc: The file containing this message.
|
| 895 |
+
scope: Dict mapping short and full symbols to message and enum types.
|
| 896 |
+
syntax: string indicating syntax of the file ("proto2" or "proto3")
|
| 897 |
+
|
| 898 |
+
Returns:
|
| 899 |
+
The added descriptor.
|
| 900 |
+
"""
|
| 901 |
+
|
| 902 |
+
if package:
|
| 903 |
+
desc_name = '.'.join((package, desc_proto.name))
|
| 904 |
+
else:
|
| 905 |
+
desc_name = desc_proto.name
|
| 906 |
+
|
| 907 |
+
if file_desc is None:
|
| 908 |
+
file_name = None
|
| 909 |
+
else:
|
| 910 |
+
file_name = file_desc.name
|
| 911 |
+
|
| 912 |
+
if scope is None:
|
| 913 |
+
scope = {}
|
| 914 |
+
|
| 915 |
+
nested = [
|
| 916 |
+
self._ConvertMessageDescriptor(
|
| 917 |
+
nested, desc_name, file_desc, scope, syntax)
|
| 918 |
+
for nested in desc_proto.nested_type]
|
| 919 |
+
enums = [
|
| 920 |
+
self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
|
| 921 |
+
scope, False)
|
| 922 |
+
for enum in desc_proto.enum_type]
|
| 923 |
+
fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
|
| 924 |
+
for index, field in enumerate(desc_proto.field)]
|
| 925 |
+
extensions = [
|
| 926 |
+
self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
|
| 927 |
+
is_extension=True)
|
| 928 |
+
for index, extension in enumerate(desc_proto.extension)]
|
| 929 |
+
oneofs = [
|
| 930 |
+
# pylint: disable=g-complex-comprehension
|
| 931 |
+
descriptor.OneofDescriptor(
|
| 932 |
+
desc.name,
|
| 933 |
+
'.'.join((desc_name, desc.name)),
|
| 934 |
+
index,
|
| 935 |
+
None,
|
| 936 |
+
[],
|
| 937 |
+
_OptionsOrNone(desc),
|
| 938 |
+
# pylint: disable=protected-access
|
| 939 |
+
create_key=descriptor._internal_create_key)
|
| 940 |
+
for index, desc in enumerate(desc_proto.oneof_decl)
|
| 941 |
+
]
|
| 942 |
+
extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
|
| 943 |
+
if extension_ranges:
|
| 944 |
+
is_extendable = True
|
| 945 |
+
else:
|
| 946 |
+
is_extendable = False
|
| 947 |
+
desc = descriptor.Descriptor(
|
| 948 |
+
name=desc_proto.name,
|
| 949 |
+
full_name=desc_name,
|
| 950 |
+
filename=file_name,
|
| 951 |
+
containing_type=None,
|
| 952 |
+
fields=fields,
|
| 953 |
+
oneofs=oneofs,
|
| 954 |
+
nested_types=nested,
|
| 955 |
+
enum_types=enums,
|
| 956 |
+
extensions=extensions,
|
| 957 |
+
options=_OptionsOrNone(desc_proto),
|
| 958 |
+
is_extendable=is_extendable,
|
| 959 |
+
extension_ranges=extension_ranges,
|
| 960 |
+
file=file_desc,
|
| 961 |
+
serialized_start=None,
|
| 962 |
+
serialized_end=None,
|
| 963 |
+
is_map_entry=desc_proto.options.map_entry,
|
| 964 |
+
# pylint: disable=protected-access
|
| 965 |
+
create_key=descriptor._internal_create_key,
|
| 966 |
+
)
|
| 967 |
+
for nested in desc.nested_types:
|
| 968 |
+
nested.containing_type = desc
|
| 969 |
+
for enum in desc.enum_types:
|
| 970 |
+
enum.containing_type = desc
|
| 971 |
+
for field_index, field_desc in enumerate(desc_proto.field):
|
| 972 |
+
if field_desc.HasField('oneof_index'):
|
| 973 |
+
oneof_index = field_desc.oneof_index
|
| 974 |
+
oneofs[oneof_index].fields.append(fields[field_index])
|
| 975 |
+
fields[field_index].containing_oneof = oneofs[oneof_index]
|
| 976 |
+
|
| 977 |
+
scope[_PrefixWithDot(desc_name)] = desc
|
| 978 |
+
self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
|
| 979 |
+
self._descriptors[desc_name] = desc
|
| 980 |
+
return desc
|
| 981 |
+
|
| 982 |
+
def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
|
| 983 |
+
containing_type=None, scope=None, top_level=False):
|
| 984 |
+
"""Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
|
| 985 |
+
|
| 986 |
+
Args:
|
| 987 |
+
enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
|
| 988 |
+
package: Optional package name for the new message EnumDescriptor.
|
| 989 |
+
file_desc: The file containing the enum descriptor.
|
| 990 |
+
containing_type: The type containing this enum.
|
| 991 |
+
scope: Scope containing available types.
|
| 992 |
+
top_level: If True, the enum is a top level symbol. If False, the enum
|
| 993 |
+
is defined inside a message.
|
| 994 |
+
|
| 995 |
+
Returns:
|
| 996 |
+
The added descriptor
|
| 997 |
+
"""
|
| 998 |
+
|
| 999 |
+
if package:
|
| 1000 |
+
enum_name = '.'.join((package, enum_proto.name))
|
| 1001 |
+
else:
|
| 1002 |
+
enum_name = enum_proto.name
|
| 1003 |
+
|
| 1004 |
+
if file_desc is None:
|
| 1005 |
+
file_name = None
|
| 1006 |
+
else:
|
| 1007 |
+
file_name = file_desc.name
|
| 1008 |
+
|
| 1009 |
+
values = [self._MakeEnumValueDescriptor(value, index)
|
| 1010 |
+
for index, value in enumerate(enum_proto.value)]
|
| 1011 |
+
desc = descriptor.EnumDescriptor(name=enum_proto.name,
|
| 1012 |
+
full_name=enum_name,
|
| 1013 |
+
filename=file_name,
|
| 1014 |
+
file=file_desc,
|
| 1015 |
+
values=values,
|
| 1016 |
+
containing_type=containing_type,
|
| 1017 |
+
options=_OptionsOrNone(enum_proto),
|
| 1018 |
+
# pylint: disable=protected-access
|
| 1019 |
+
create_key=descriptor._internal_create_key)
|
| 1020 |
+
scope['.%s' % enum_name] = desc
|
| 1021 |
+
self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
|
| 1022 |
+
self._enum_descriptors[enum_name] = desc
|
| 1023 |
+
|
| 1024 |
+
# Add top level enum values.
|
| 1025 |
+
if top_level:
|
| 1026 |
+
for value in values:
|
| 1027 |
+
full_name = _NormalizeFullyQualifiedName(
|
| 1028 |
+
'.'.join((package, value.name)))
|
| 1029 |
+
self._CheckConflictRegister(value, full_name, file_name)
|
| 1030 |
+
self._top_enum_values[full_name] = value
|
| 1031 |
+
|
| 1032 |
+
return desc
|
| 1033 |
+
|
| 1034 |
+
def _MakeFieldDescriptor(self, field_proto, message_name, index,
|
| 1035 |
+
file_desc, is_extension=False):
|
| 1036 |
+
"""Creates a field descriptor from a FieldDescriptorProto.
|
| 1037 |
+
|
| 1038 |
+
For message and enum type fields, this method will do a look up
|
| 1039 |
+
in the pool for the appropriate descriptor for that type. If it
|
| 1040 |
+
is unavailable, it will fall back to the _source function to
|
| 1041 |
+
create it. If this type is still unavailable, construction will
|
| 1042 |
+
fail.
|
| 1043 |
+
|
| 1044 |
+
Args:
|
| 1045 |
+
field_proto: The proto describing the field.
|
| 1046 |
+
message_name: The name of the containing message.
|
| 1047 |
+
index: Index of the field
|
| 1048 |
+
file_desc: The file containing the field descriptor.
|
| 1049 |
+
is_extension: Indication that this field is for an extension.
|
| 1050 |
+
|
| 1051 |
+
Returns:
|
| 1052 |
+
An initialized FieldDescriptor object
|
| 1053 |
+
"""
|
| 1054 |
+
|
| 1055 |
+
if message_name:
|
| 1056 |
+
full_name = '.'.join((message_name, field_proto.name))
|
| 1057 |
+
else:
|
| 1058 |
+
full_name = field_proto.name
|
| 1059 |
+
|
| 1060 |
+
if field_proto.json_name:
|
| 1061 |
+
json_name = field_proto.json_name
|
| 1062 |
+
else:
|
| 1063 |
+
json_name = None
|
| 1064 |
+
|
| 1065 |
+
return descriptor.FieldDescriptor(
|
| 1066 |
+
name=field_proto.name,
|
| 1067 |
+
full_name=full_name,
|
| 1068 |
+
index=index,
|
| 1069 |
+
number=field_proto.number,
|
| 1070 |
+
type=field_proto.type,
|
| 1071 |
+
cpp_type=None,
|
| 1072 |
+
message_type=None,
|
| 1073 |
+
enum_type=None,
|
| 1074 |
+
containing_type=None,
|
| 1075 |
+
label=field_proto.label,
|
| 1076 |
+
has_default_value=False,
|
| 1077 |
+
default_value=None,
|
| 1078 |
+
is_extension=is_extension,
|
| 1079 |
+
extension_scope=None,
|
| 1080 |
+
options=_OptionsOrNone(field_proto),
|
| 1081 |
+
json_name=json_name,
|
| 1082 |
+
file=file_desc,
|
| 1083 |
+
# pylint: disable=protected-access
|
| 1084 |
+
create_key=descriptor._internal_create_key)
|
| 1085 |
+
|
| 1086 |
+
def _SetAllFieldTypes(self, package, desc_proto, scope):
|
| 1087 |
+
"""Sets all the descriptor's fields's types.
|
| 1088 |
+
|
| 1089 |
+
This method also sets the containing types on any extensions.
|
| 1090 |
+
|
| 1091 |
+
Args:
|
| 1092 |
+
package: The current package of desc_proto.
|
| 1093 |
+
desc_proto: The message descriptor to update.
|
| 1094 |
+
scope: Enclosing scope of available types.
|
| 1095 |
+
"""
|
| 1096 |
+
|
| 1097 |
+
package = _PrefixWithDot(package)
|
| 1098 |
+
|
| 1099 |
+
main_desc = self._GetTypeFromScope(package, desc_proto.name, scope)
|
| 1100 |
+
|
| 1101 |
+
if package == '.':
|
| 1102 |
+
nested_package = _PrefixWithDot(desc_proto.name)
|
| 1103 |
+
else:
|
| 1104 |
+
nested_package = '.'.join([package, desc_proto.name])
|
| 1105 |
+
|
| 1106 |
+
for field_proto, field_desc in zip(desc_proto.field, main_desc.fields):
|
| 1107 |
+
self._SetFieldType(field_proto, field_desc, nested_package, scope)
|
| 1108 |
+
|
| 1109 |
+
for extension_proto, extension_desc in (
|
| 1110 |
+
zip(desc_proto.extension, main_desc.extensions)):
|
| 1111 |
+
extension_desc.containing_type = self._GetTypeFromScope(
|
| 1112 |
+
nested_package, extension_proto.extendee, scope)
|
| 1113 |
+
self._SetFieldType(extension_proto, extension_desc, nested_package, scope)
|
| 1114 |
+
|
| 1115 |
+
for nested_type in desc_proto.nested_type:
|
| 1116 |
+
self._SetAllFieldTypes(nested_package, nested_type, scope)
|
| 1117 |
+
|
| 1118 |
+
def _SetFieldType(self, field_proto, field_desc, package, scope):
|
| 1119 |
+
"""Sets the field's type, cpp_type, message_type and enum_type.
|
| 1120 |
+
|
| 1121 |
+
Args:
|
| 1122 |
+
field_proto: Data about the field in proto format.
|
| 1123 |
+
field_desc: The descriptor to modify.
|
| 1124 |
+
package: The package the field's container is in.
|
| 1125 |
+
scope: Enclosing scope of available types.
|
| 1126 |
+
"""
|
| 1127 |
+
if field_proto.type_name:
|
| 1128 |
+
desc = self._GetTypeFromScope(package, field_proto.type_name, scope)
|
| 1129 |
+
else:
|
| 1130 |
+
desc = None
|
| 1131 |
+
|
| 1132 |
+
if not field_proto.HasField('type'):
|
| 1133 |
+
if isinstance(desc, descriptor.Descriptor):
|
| 1134 |
+
field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE
|
| 1135 |
+
else:
|
| 1136 |
+
field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM
|
| 1137 |
+
|
| 1138 |
+
field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType(
|
| 1139 |
+
field_proto.type)
|
| 1140 |
+
|
| 1141 |
+
if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE
|
| 1142 |
+
or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP):
|
| 1143 |
+
field_desc.message_type = desc
|
| 1144 |
+
|
| 1145 |
+
if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
|
| 1146 |
+
field_desc.enum_type = desc
|
| 1147 |
+
|
| 1148 |
+
if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED:
|
| 1149 |
+
field_desc.has_default_value = False
|
| 1150 |
+
field_desc.default_value = []
|
| 1151 |
+
elif field_proto.HasField('default_value'):
|
| 1152 |
+
field_desc.has_default_value = True
|
| 1153 |
+
if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
|
| 1154 |
+
field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
|
| 1155 |
+
field_desc.default_value = float(field_proto.default_value)
|
| 1156 |
+
elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
|
| 1157 |
+
field_desc.default_value = field_proto.default_value
|
| 1158 |
+
elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
|
| 1159 |
+
field_desc.default_value = field_proto.default_value.lower() == 'true'
|
| 1160 |
+
elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
|
| 1161 |
+
field_desc.default_value = field_desc.enum_type.values_by_name[
|
| 1162 |
+
field_proto.default_value].number
|
| 1163 |
+
elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
|
| 1164 |
+
field_desc.default_value = text_encoding.CUnescape(
|
| 1165 |
+
field_proto.default_value)
|
| 1166 |
+
elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
|
| 1167 |
+
field_desc.default_value = None
|
| 1168 |
+
else:
|
| 1169 |
+
# All other types are of the "int" type.
|
| 1170 |
+
field_desc.default_value = int(field_proto.default_value)
|
| 1171 |
+
else:
|
| 1172 |
+
field_desc.has_default_value = False
|
| 1173 |
+
if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
|
| 1174 |
+
field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
|
| 1175 |
+
field_desc.default_value = 0.0
|
| 1176 |
+
elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
|
| 1177 |
+
field_desc.default_value = u''
|
| 1178 |
+
elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
|
| 1179 |
+
field_desc.default_value = False
|
| 1180 |
+
elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
|
| 1181 |
+
field_desc.default_value = field_desc.enum_type.values[0].number
|
| 1182 |
+
elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
|
| 1183 |
+
field_desc.default_value = b''
|
| 1184 |
+
elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
|
| 1185 |
+
field_desc.default_value = None
|
| 1186 |
+
elif field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP:
|
| 1187 |
+
field_desc.default_value = None
|
| 1188 |
+
else:
|
| 1189 |
+
# All other types are of the "int" type.
|
| 1190 |
+
field_desc.default_value = 0
|
| 1191 |
+
|
| 1192 |
+
field_desc.type = field_proto.type
|
| 1193 |
+
|
| 1194 |
+
def _MakeEnumValueDescriptor(self, value_proto, index):
|
| 1195 |
+
"""Creates a enum value descriptor object from a enum value proto.
|
| 1196 |
+
|
| 1197 |
+
Args:
|
| 1198 |
+
value_proto: The proto describing the enum value.
|
| 1199 |
+
index: The index of the enum value.
|
| 1200 |
+
|
| 1201 |
+
Returns:
|
| 1202 |
+
An initialized EnumValueDescriptor object.
|
| 1203 |
+
"""
|
| 1204 |
+
|
| 1205 |
+
return descriptor.EnumValueDescriptor(
|
| 1206 |
+
name=value_proto.name,
|
| 1207 |
+
index=index,
|
| 1208 |
+
number=value_proto.number,
|
| 1209 |
+
options=_OptionsOrNone(value_proto),
|
| 1210 |
+
type=None,
|
| 1211 |
+
# pylint: disable=protected-access
|
| 1212 |
+
create_key=descriptor._internal_create_key)
|
| 1213 |
+
|
| 1214 |
+
def _MakeServiceDescriptor(self, service_proto, service_index, scope,
|
| 1215 |
+
package, file_desc):
|
| 1216 |
+
"""Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.
|
| 1217 |
+
|
| 1218 |
+
Args:
|
| 1219 |
+
service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
|
| 1220 |
+
service_index: The index of the service in the File.
|
| 1221 |
+
scope: Dict mapping short and full symbols to message and enum types.
|
| 1222 |
+
package: Optional package name for the new message EnumDescriptor.
|
| 1223 |
+
file_desc: The file containing the service descriptor.
|
| 1224 |
+
|
| 1225 |
+
Returns:
|
| 1226 |
+
The added descriptor.
|
| 1227 |
+
"""
|
| 1228 |
+
|
| 1229 |
+
if package:
|
| 1230 |
+
service_name = '.'.join((package, service_proto.name))
|
| 1231 |
+
else:
|
| 1232 |
+
service_name = service_proto.name
|
| 1233 |
+
|
| 1234 |
+
methods = [self._MakeMethodDescriptor(method_proto, service_name, package,
|
| 1235 |
+
scope, index)
|
| 1236 |
+
for index, method_proto in enumerate(service_proto.method)]
|
| 1237 |
+
desc = descriptor.ServiceDescriptor(
|
| 1238 |
+
name=service_proto.name,
|
| 1239 |
+
full_name=service_name,
|
| 1240 |
+
index=service_index,
|
| 1241 |
+
methods=methods,
|
| 1242 |
+
options=_OptionsOrNone(service_proto),
|
| 1243 |
+
file=file_desc,
|
| 1244 |
+
# pylint: disable=protected-access
|
| 1245 |
+
create_key=descriptor._internal_create_key)
|
| 1246 |
+
self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
|
| 1247 |
+
self._service_descriptors[service_name] = desc
|
| 1248 |
+
return desc
|
| 1249 |
+
|
| 1250 |
+
def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
|
| 1251 |
+
index):
|
| 1252 |
+
"""Creates a method descriptor from a MethodDescriptorProto.
|
| 1253 |
+
|
| 1254 |
+
Args:
|
| 1255 |
+
method_proto: The proto describing the method.
|
| 1256 |
+
service_name: The name of the containing service.
|
| 1257 |
+
package: Optional package name to look up for types.
|
| 1258 |
+
scope: Scope containing available types.
|
| 1259 |
+
index: Index of the method in the service.
|
| 1260 |
+
|
| 1261 |
+
Returns:
|
| 1262 |
+
An initialized MethodDescriptor object.
|
| 1263 |
+
"""
|
| 1264 |
+
full_name = '.'.join((service_name, method_proto.name))
|
| 1265 |
+
input_type = self._GetTypeFromScope(
|
| 1266 |
+
package, method_proto.input_type, scope)
|
| 1267 |
+
output_type = self._GetTypeFromScope(
|
| 1268 |
+
package, method_proto.output_type, scope)
|
| 1269 |
+
return descriptor.MethodDescriptor(
|
| 1270 |
+
name=method_proto.name,
|
| 1271 |
+
full_name=full_name,
|
| 1272 |
+
index=index,
|
| 1273 |
+
containing_service=None,
|
| 1274 |
+
input_type=input_type,
|
| 1275 |
+
output_type=output_type,
|
| 1276 |
+
client_streaming=method_proto.client_streaming,
|
| 1277 |
+
server_streaming=method_proto.server_streaming,
|
| 1278 |
+
options=_OptionsOrNone(method_proto),
|
| 1279 |
+
# pylint: disable=protected-access
|
| 1280 |
+
create_key=descriptor._internal_create_key)
|
| 1281 |
+
|
| 1282 |
+
def _ExtractSymbols(self, descriptors):
|
| 1283 |
+
"""Pulls out all the symbols from descriptor protos.
|
| 1284 |
+
|
| 1285 |
+
Args:
|
| 1286 |
+
descriptors: The messages to extract descriptors from.
|
| 1287 |
+
Yields:
|
| 1288 |
+
A two element tuple of the type name and descriptor object.
|
| 1289 |
+
"""
|
| 1290 |
+
|
| 1291 |
+
for desc in descriptors:
|
| 1292 |
+
yield (_PrefixWithDot(desc.full_name), desc)
|
| 1293 |
+
for symbol in self._ExtractSymbols(desc.nested_types):
|
| 1294 |
+
yield symbol
|
| 1295 |
+
for enum in desc.enum_types:
|
| 1296 |
+
yield (_PrefixWithDot(enum.full_name), enum)
|
| 1297 |
+
|
| 1298 |
+
def _GetDeps(self, dependencies, visited=None):
|
| 1299 |
+
"""Recursively finds dependencies for file protos.
|
| 1300 |
+
|
| 1301 |
+
Args:
|
| 1302 |
+
dependencies: The names of the files being depended on.
|
| 1303 |
+
visited: The names of files already found.
|
| 1304 |
+
|
| 1305 |
+
Yields:
|
| 1306 |
+
Each direct and indirect dependency.
|
| 1307 |
+
"""
|
| 1308 |
+
|
| 1309 |
+
visited = visited or set()
|
| 1310 |
+
for dependency in dependencies:
|
| 1311 |
+
if dependency not in visited:
|
| 1312 |
+
visited.add(dependency)
|
| 1313 |
+
dep_desc = self.FindFileByName(dependency)
|
| 1314 |
+
yield dep_desc
|
| 1315 |
+
public_files = [d.name for d in dep_desc.public_dependencies]
|
| 1316 |
+
yield from self._GetDeps(public_files, visited)
|
| 1317 |
+
|
| 1318 |
+
def _GetTypeFromScope(self, package, type_name, scope):
|
| 1319 |
+
"""Finds a given type name in the current scope.
|
| 1320 |
+
|
| 1321 |
+
Args:
|
| 1322 |
+
package: The package the proto should be located in.
|
| 1323 |
+
type_name: The name of the type to be found in the scope.
|
| 1324 |
+
scope: Dict mapping short and full symbols to message and enum types.
|
| 1325 |
+
|
| 1326 |
+
Returns:
|
| 1327 |
+
The descriptor for the requested type.
|
| 1328 |
+
"""
|
| 1329 |
+
if type_name not in scope:
|
| 1330 |
+
components = _PrefixWithDot(package).split('.')
|
| 1331 |
+
while components:
|
| 1332 |
+
possible_match = '.'.join(components + [type_name])
|
| 1333 |
+
if possible_match in scope:
|
| 1334 |
+
type_name = possible_match
|
| 1335 |
+
break
|
| 1336 |
+
else:
|
| 1337 |
+
components.pop(-1)
|
| 1338 |
+
return scope[type_name]
|
| 1339 |
+
|
| 1340 |
+
|
| 1341 |
+
def _PrefixWithDot(name):
|
| 1342 |
+
return name if name.startswith('.') else '.%s' % name
|
| 1343 |
+
|
| 1344 |
+
|
| 1345 |
+
if _USE_C_DESCRIPTORS:
|
| 1346 |
+
# TODO: This pool could be constructed from Python code, when we
|
| 1347 |
+
# support a flag like 'use_cpp_generated_pool=True'.
|
| 1348 |
+
# pylint: disable=protected-access
|
| 1349 |
+
_DEFAULT = descriptor._message.default_pool
|
| 1350 |
+
else:
|
| 1351 |
+
_DEFAULT = DescriptorPool()
|
| 1352 |
+
|
| 1353 |
+
|
| 1354 |
+
def Default():
|
| 1355 |
+
return _DEFAULT
|
.venv/lib/python3.11/site-packages/google/protobuf/duration.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Protocol Buffers - Google's data interchange format
|
| 2 |
+
# Copyright 2008 Google Inc. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Use of this source code is governed by a BSD-style
|
| 5 |
+
# license that can be found in the LICENSE file or at
|
| 6 |
+
# https://developers.google.com/open-source/licenses/bsd
|
| 7 |
+
|
| 8 |
+
"""Contains the Duration helper APIs."""
|
| 9 |
+
|
| 10 |
+
import datetime
|
| 11 |
+
|
| 12 |
+
from google.protobuf.duration_pb2 import Duration
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def from_json_string(value: str) -> Duration:
|
| 16 |
+
"""Converts a string to Duration.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
value: A string to be converted. The string must end with 's'. Any
|
| 20 |
+
fractional digits (or none) are accepted as long as they fit into
|
| 21 |
+
precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s"
|
| 22 |
+
|
| 23 |
+
Raises:
|
| 24 |
+
ValueError: On parsing problems.
|
| 25 |
+
"""
|
| 26 |
+
duration = Duration()
|
| 27 |
+
duration.FromJsonString(value)
|
| 28 |
+
return duration
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def from_microseconds(micros: float) -> Duration:
|
| 32 |
+
"""Converts microseconds to Duration."""
|
| 33 |
+
duration = Duration()
|
| 34 |
+
duration.FromMicroseconds(micros)
|
| 35 |
+
return duration
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def from_milliseconds(millis: float) -> Duration:
|
| 39 |
+
"""Converts milliseconds to Duration."""
|
| 40 |
+
duration = Duration()
|
| 41 |
+
duration.FromMilliseconds(millis)
|
| 42 |
+
return duration
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def from_nanoseconds(nanos: float) -> Duration:
|
| 46 |
+
"""Converts nanoseconds to Duration."""
|
| 47 |
+
duration = Duration()
|
| 48 |
+
duration.FromNanoseconds(nanos)
|
| 49 |
+
return duration
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def from_seconds(seconds: float) -> Duration:
|
| 53 |
+
"""Converts seconds to Duration."""
|
| 54 |
+
duration = Duration()
|
| 55 |
+
duration.FromSeconds(seconds)
|
| 56 |
+
return duration
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def from_timedelta(td: datetime.timedelta) -> Duration:
|
| 60 |
+
"""Converts timedelta to Duration."""
|
| 61 |
+
duration = Duration()
|
| 62 |
+
duration.FromTimedelta(td)
|
| 63 |
+
return duration
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def to_json_string(duration: Duration) -> str:
|
| 67 |
+
"""Converts Duration to string format.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
A string converted from self. The string format will contains
|
| 71 |
+
3, 6, or 9 fractional digits depending on the precision required to
|
| 72 |
+
represent the exact Duration value. For example: "1s", "1.010s",
|
| 73 |
+
"1.000000100s", "-3.100s"
|
| 74 |
+
"""
|
| 75 |
+
return duration.ToJsonString()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def to_microseconds(duration: Duration) -> int:
|
| 79 |
+
"""Converts a Duration to microseconds."""
|
| 80 |
+
return duration.ToMicroseconds()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def to_milliseconds(duration: Duration) -> int:
|
| 84 |
+
"""Converts a Duration to milliseconds."""
|
| 85 |
+
return duration.ToMilliseconds()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def to_nanoseconds(duration: Duration) -> int:
|
| 89 |
+
"""Converts a Duration to nanoseconds."""
|
| 90 |
+
return duration.ToNanoseconds()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def to_seconds(duration: Duration) -> int:
|
| 94 |
+
"""Converts a Duration to seconds."""
|
| 95 |
+
return duration.ToSeconds()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def to_timedelta(duration: Duration) -> datetime.timedelta:
|
| 99 |
+
"""Converts Duration to timedelta."""
|
| 100 |
+
return duration.ToTimedelta()
|
.venv/lib/python3.11/site-packages/google/protobuf/duration_pb2.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# NO CHECKED-IN PROTOBUF GENCODE
|
| 4 |
+
# source: google/protobuf/duration.proto
|
| 5 |
+
# Protobuf Python Version: 5.29.3
|
| 6 |
+
"""Generated protocol buffer code."""
|
| 7 |
+
from google.protobuf import descriptor as _descriptor
|
| 8 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 9 |
+
from google.protobuf import runtime_version as _runtime_version
|
| 10 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 11 |
+
from google.protobuf.internal import builder as _builder
|
| 12 |
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
| 13 |
+
_runtime_version.Domain.PUBLIC,
|
| 14 |
+
5,
|
| 15 |
+
29,
|
| 16 |
+
3,
|
| 17 |
+
'',
|
| 18 |
+
'google/protobuf/duration.proto'
|
| 19 |
+
)
|
| 20 |
+
# @@protoc_insertion_point(imports)
|
| 21 |
+
|
| 22 |
+
_sym_db = _symbol_database.Default()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1egoogle/protobuf/duration.proto\x12\x0fgoogle.protobuf\":\n\x08\x44uration\x12\x18\n\x07seconds\x18\x01 \x01(\x03R\x07seconds\x12\x14\n\x05nanos\x18\x02 \x01(\x05R\x05nanosB\x83\x01\n\x13\x63om.google.protobufB\rDurationProtoP\x01Z1google.golang.org/protobuf/types/known/durationpb\xf8\x01\x01\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')
|
| 28 |
+
|
| 29 |
+
_globals = globals()
|
| 30 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
| 31 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.duration_pb2', _globals)
|
| 32 |
+
if not _descriptor._USE_C_DESCRIPTORS:
|
| 33 |
+
_globals['DESCRIPTOR']._loaded_options = None
|
| 34 |
+
_globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\rDurationProtoP\001Z1google.golang.org/protobuf/types/known/durationpb\370\001\001\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
|
| 35 |
+
_globals['_DURATION']._serialized_start=51
|
| 36 |
+
_globals['_DURATION']._serialized_end=109
|
| 37 |
+
# @@protoc_insertion_point(module_scope)
|
.venv/lib/python3.11/site-packages/google/protobuf/empty_pb2.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# NO CHECKED-IN PROTOBUF GENCODE
|
| 4 |
+
# source: google/protobuf/empty.proto
|
| 5 |
+
# Protobuf Python Version: 5.29.3
|
| 6 |
+
"""Generated protocol buffer code."""
|
| 7 |
+
from google.protobuf import descriptor as _descriptor
|
| 8 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 9 |
+
from google.protobuf import runtime_version as _runtime_version
|
| 10 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 11 |
+
from google.protobuf.internal import builder as _builder
|
| 12 |
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
| 13 |
+
_runtime_version.Domain.PUBLIC,
|
| 14 |
+
5,
|
| 15 |
+
29,
|
| 16 |
+
3,
|
| 17 |
+
'',
|
| 18 |
+
'google/protobuf/empty.proto'
|
| 19 |
+
)
|
| 20 |
+
# @@protoc_insertion_point(imports)
|
| 21 |
+
|
| 22 |
+
_sym_db = _symbol_database.Default()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1bgoogle/protobuf/empty.proto\x12\x0fgoogle.protobuf\"\x07\n\x05\x45mptyB}\n\x13\x63om.google.protobufB\nEmptyProtoP\x01Z.google.golang.org/protobuf/types/known/emptypb\xf8\x01\x01\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')
|
| 28 |
+
|
| 29 |
+
_globals = globals()
|
| 30 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
| 31 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.empty_pb2', _globals)
|
| 32 |
+
if not _descriptor._USE_C_DESCRIPTORS:
|
| 33 |
+
_globals['DESCRIPTOR']._loaded_options = None
|
| 34 |
+
_globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\nEmptyProtoP\001Z.google.golang.org/protobuf/types/known/emptypb\370\001\001\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
|
| 35 |
+
_globals['_EMPTY']._serialized_start=48
|
| 36 |
+
_globals['_EMPTY']._serialized_end=55
|
| 37 |
+
# @@protoc_insertion_point(module_scope)
|
.venv/lib/python3.11/site-packages/google/protobuf/field_mask_pb2.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# NO CHECKED-IN PROTOBUF GENCODE
|
| 4 |
+
# source: google/protobuf/field_mask.proto
|
| 5 |
+
# Protobuf Python Version: 5.29.3
|
| 6 |
+
"""Generated protocol buffer code."""
|
| 7 |
+
from google.protobuf import descriptor as _descriptor
|
| 8 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 9 |
+
from google.protobuf import runtime_version as _runtime_version
|
| 10 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 11 |
+
from google.protobuf.internal import builder as _builder
|
| 12 |
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
| 13 |
+
_runtime_version.Domain.PUBLIC,
|
| 14 |
+
5,
|
| 15 |
+
29,
|
| 16 |
+
3,
|
| 17 |
+
'',
|
| 18 |
+
'google/protobuf/field_mask.proto'
|
| 19 |
+
)
|
| 20 |
+
# @@protoc_insertion_point(imports)
|
| 21 |
+
|
| 22 |
+
_sym_db = _symbol_database.Default()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n google/protobuf/field_mask.proto\x12\x0fgoogle.protobuf\"!\n\tFieldMask\x12\x14\n\x05paths\x18\x01 \x03(\tR\x05pathsB\x85\x01\n\x13\x63om.google.protobufB\x0e\x46ieldMaskProtoP\x01Z2google.golang.org/protobuf/types/known/fieldmaskpb\xf8\x01\x01\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')
|
| 28 |
+
|
| 29 |
+
_globals = globals()
|
| 30 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
| 31 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.field_mask_pb2', _globals)
|
| 32 |
+
if not _descriptor._USE_C_DESCRIPTORS:
|
| 33 |
+
_globals['DESCRIPTOR']._loaded_options = None
|
| 34 |
+
_globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\016FieldMaskProtoP\001Z2google.golang.org/protobuf/types/known/fieldmaskpb\370\001\001\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
|
| 35 |
+
_globals['_FIELDMASK']._serialized_start=53
|
| 36 |
+
_globals['_FIELDMASK']._serialized_end=86
|
| 37 |
+
# @@protoc_insertion_point(module_scope)
|
.venv/lib/python3.11/site-packages/google/protobuf/internal/_parameterized.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#! /usr/bin/env python
|
| 2 |
+
#
|
| 3 |
+
# Protocol Buffers - Google's data interchange format
|
| 4 |
+
# Copyright 2008 Google Inc. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Use of this source code is governed by a BSD-style
|
| 7 |
+
# license that can be found in the LICENSE file or at
|
| 8 |
+
# https://developers.google.com/open-source/licenses/bsd
|
| 9 |
+
|
| 10 |
+
"""Adds support for parameterized tests to Python's unittest TestCase class.
|
| 11 |
+
|
| 12 |
+
A parameterized test is a method in a test case that is invoked with different
|
| 13 |
+
argument tuples.
|
| 14 |
+
|
| 15 |
+
A simple example:
|
| 16 |
+
|
| 17 |
+
class AdditionExample(_parameterized.TestCase):
|
| 18 |
+
@_parameterized.parameters(
|
| 19 |
+
(1, 2, 3),
|
| 20 |
+
(4, 5, 9),
|
| 21 |
+
(1, 1, 3))
|
| 22 |
+
def testAddition(self, op1, op2, result):
|
| 23 |
+
self.assertEqual(result, op1 + op2)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
Each invocation is a separate test case and properly isolated just
|
| 27 |
+
like a normal test method, with its own setUp/tearDown cycle. In the
|
| 28 |
+
example above, there are three separate testcases, one of which will
|
| 29 |
+
fail due to an assertion error (1 + 1 != 3).
|
| 30 |
+
|
| 31 |
+
Parameters for individual test cases can be tuples (with positional parameters)
|
| 32 |
+
or dictionaries (with named parameters):
|
| 33 |
+
|
| 34 |
+
class AdditionExample(_parameterized.TestCase):
|
| 35 |
+
@_parameterized.parameters(
|
| 36 |
+
{'op1': 1, 'op2': 2, 'result': 3},
|
| 37 |
+
{'op1': 4, 'op2': 5, 'result': 9},
|
| 38 |
+
)
|
| 39 |
+
def testAddition(self, op1, op2, result):
|
| 40 |
+
self.assertEqual(result, op1 + op2)
|
| 41 |
+
|
| 42 |
+
If a parameterized test fails, the error message will show the
|
| 43 |
+
original test name (which is modified internally) and the arguments
|
| 44 |
+
for the specific invocation, which are part of the string returned by
|
| 45 |
+
the shortDescription() method on test cases.
|
| 46 |
+
|
| 47 |
+
The id method of the test, used internally by the unittest framework,
|
| 48 |
+
is also modified to show the arguments. To make sure that test names
|
| 49 |
+
stay the same across several invocations, object representations like
|
| 50 |
+
|
| 51 |
+
>>> class Foo(object):
|
| 52 |
+
... pass
|
| 53 |
+
>>> repr(Foo())
|
| 54 |
+
'<__main__.Foo object at 0x23d8610>'
|
| 55 |
+
|
| 56 |
+
are turned into '<__main__.Foo>'. For even more descriptive names,
|
| 57 |
+
especially in test logs, you can use the named_parameters decorator. In
|
| 58 |
+
this case, only tuples are supported, and the first parameters has to
|
| 59 |
+
be a string (or an object that returns an apt name when converted via
|
| 60 |
+
str()):
|
| 61 |
+
|
| 62 |
+
class NamedExample(_parameterized.TestCase):
|
| 63 |
+
@_parameterized.named_parameters(
|
| 64 |
+
('Normal', 'aa', 'aaa', True),
|
| 65 |
+
('EmptyPrefix', '', 'abc', True),
|
| 66 |
+
('BothEmpty', '', '', True))
|
| 67 |
+
def testStartsWith(self, prefix, string, result):
|
| 68 |
+
self.assertEqual(result, strings.startswith(prefix))
|
| 69 |
+
|
| 70 |
+
Named tests also have the benefit that they can be run individually
|
| 71 |
+
from the command line:
|
| 72 |
+
|
| 73 |
+
$ testmodule.py NamedExample.testStartsWithNormal
|
| 74 |
+
.
|
| 75 |
+
--------------------------------------------------------------------
|
| 76 |
+
Ran 1 test in 0.000s
|
| 77 |
+
|
| 78 |
+
OK
|
| 79 |
+
|
| 80 |
+
Parameterized Classes
|
| 81 |
+
=====================
|
| 82 |
+
If invocation arguments are shared across test methods in a single
|
| 83 |
+
TestCase class, instead of decorating all test methods
|
| 84 |
+
individually, the class itself can be decorated:
|
| 85 |
+
|
| 86 |
+
@_parameterized.parameters(
|
| 87 |
+
(1, 2, 3)
|
| 88 |
+
(4, 5, 9))
|
| 89 |
+
class ArithmeticTest(_parameterized.TestCase):
|
| 90 |
+
def testAdd(self, arg1, arg2, result):
|
| 91 |
+
self.assertEqual(arg1 + arg2, result)
|
| 92 |
+
|
| 93 |
+
def testSubtract(self, arg2, arg2, result):
|
| 94 |
+
self.assertEqual(result - arg1, arg2)
|
| 95 |
+
|
| 96 |
+
Inputs from Iterables
|
| 97 |
+
=====================
|
| 98 |
+
If parameters should be shared across several test cases, or are dynamically
|
| 99 |
+
created from other sources, a single non-tuple iterable can be passed into
|
| 100 |
+
the decorator. This iterable will be used to obtain the test cases:
|
| 101 |
+
|
| 102 |
+
class AdditionExample(_parameterized.TestCase):
|
| 103 |
+
@_parameterized.parameters(
|
| 104 |
+
c.op1, c.op2, c.result for c in testcases
|
| 105 |
+
)
|
| 106 |
+
def testAddition(self, op1, op2, result):
|
| 107 |
+
self.assertEqual(result, op1 + op2)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
Single-Argument Test Methods
|
| 111 |
+
============================
|
| 112 |
+
If a test method takes only one argument, the single argument does not need to
|
| 113 |
+
be wrapped into a tuple:
|
| 114 |
+
|
| 115 |
+
class NegativeNumberExample(_parameterized.TestCase):
|
| 116 |
+
@_parameterized.parameters(
|
| 117 |
+
-1, -3, -4, -5
|
| 118 |
+
)
|
| 119 |
+
def testIsNegative(self, arg):
|
| 120 |
+
self.assertTrue(IsNegative(arg))
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
__author__ = 'tmarek@google.com (Torsten Marek)'
|
| 124 |
+
|
| 125 |
+
import functools
|
| 126 |
+
import re
|
| 127 |
+
import types
|
| 128 |
+
import unittest
|
| 129 |
+
import uuid
|
| 130 |
+
|
| 131 |
+
try:
|
| 132 |
+
# Since python 3
|
| 133 |
+
import collections.abc as collections_abc
|
| 134 |
+
except ImportError:
|
| 135 |
+
# Won't work after python 3.8
|
| 136 |
+
import collections as collections_abc
|
| 137 |
+
|
| 138 |
+
ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>')
|
| 139 |
+
_SEPARATOR = uuid.uuid1().hex
|
| 140 |
+
_FIRST_ARG = object()
|
| 141 |
+
_ARGUMENT_REPR = object()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _CleanRepr(obj):
|
| 145 |
+
return ADDR_RE.sub(r'<\1>', repr(obj))
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Helper function formerly from the unittest module, removed from it in
|
| 149 |
+
# Python 2.7.
|
| 150 |
+
def _StrClass(cls):
|
| 151 |
+
return '%s.%s' % (cls.__module__, cls.__name__)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _NonStringIterable(obj):
|
| 155 |
+
return (isinstance(obj, collections_abc.Iterable) and
|
| 156 |
+
not isinstance(obj, str))
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _FormatParameterList(testcase_params):
|
| 160 |
+
if isinstance(testcase_params, collections_abc.Mapping):
|
| 161 |
+
return ', '.join('%s=%s' % (argname, _CleanRepr(value))
|
| 162 |
+
for argname, value in testcase_params.items())
|
| 163 |
+
elif _NonStringIterable(testcase_params):
|
| 164 |
+
return ', '.join(map(_CleanRepr, testcase_params))
|
| 165 |
+
else:
|
| 166 |
+
return _FormatParameterList((testcase_params,))
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class _ParameterizedTestIter(object):
|
| 170 |
+
"""Callable and iterable class for producing new test cases."""
|
| 171 |
+
|
| 172 |
+
def __init__(self, test_method, testcases, naming_type):
|
| 173 |
+
"""Returns concrete test functions for a test and a list of parameters.
|
| 174 |
+
|
| 175 |
+
The naming_type is used to determine the name of the concrete
|
| 176 |
+
functions as reported by the unittest framework. If naming_type is
|
| 177 |
+
_FIRST_ARG, the testcases must be tuples, and the first element must
|
| 178 |
+
have a string representation that is a valid Python identifier.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
test_method: The decorated test method.
|
| 182 |
+
testcases: (list of tuple/dict) A list of parameter
|
| 183 |
+
tuples/dicts for individual test invocations.
|
| 184 |
+
naming_type: The test naming type, either _NAMED or _ARGUMENT_REPR.
|
| 185 |
+
"""
|
| 186 |
+
self._test_method = test_method
|
| 187 |
+
self.testcases = testcases
|
| 188 |
+
self._naming_type = naming_type
|
| 189 |
+
|
| 190 |
+
def __call__(self, *args, **kwargs):
|
| 191 |
+
raise RuntimeError('You appear to be running a parameterized test case '
|
| 192 |
+
'without having inherited from parameterized.'
|
| 193 |
+
'TestCase. This is bad because none of '
|
| 194 |
+
'your test cases are actually being run.')
|
| 195 |
+
|
| 196 |
+
def __iter__(self):
|
| 197 |
+
test_method = self._test_method
|
| 198 |
+
naming_type = self._naming_type
|
| 199 |
+
|
| 200 |
+
def MakeBoundParamTest(testcase_params):
|
| 201 |
+
@functools.wraps(test_method)
|
| 202 |
+
def BoundParamTest(self):
|
| 203 |
+
if isinstance(testcase_params, collections_abc.Mapping):
|
| 204 |
+
test_method(self, **testcase_params)
|
| 205 |
+
elif _NonStringIterable(testcase_params):
|
| 206 |
+
test_method(self, *testcase_params)
|
| 207 |
+
else:
|
| 208 |
+
test_method(self, testcase_params)
|
| 209 |
+
|
| 210 |
+
if naming_type is _FIRST_ARG:
|
| 211 |
+
# Signal the metaclass that the name of the test function is unique
|
| 212 |
+
# and descriptive.
|
| 213 |
+
BoundParamTest.__x_use_name__ = True
|
| 214 |
+
BoundParamTest.__name__ += str(testcase_params[0])
|
| 215 |
+
testcase_params = testcase_params[1:]
|
| 216 |
+
elif naming_type is _ARGUMENT_REPR:
|
| 217 |
+
# __x_extra_id__ is used to pass naming information to the __new__
|
| 218 |
+
# method of TestGeneratorMetaclass.
|
| 219 |
+
# The metaclass will make sure to create a unique, but nondescriptive
|
| 220 |
+
# name for this test.
|
| 221 |
+
BoundParamTest.__x_extra_id__ = '(%s)' % (
|
| 222 |
+
_FormatParameterList(testcase_params),)
|
| 223 |
+
else:
|
| 224 |
+
raise RuntimeError('%s is not a valid naming type.' % (naming_type,))
|
| 225 |
+
|
| 226 |
+
BoundParamTest.__doc__ = '%s(%s)' % (
|
| 227 |
+
BoundParamTest.__name__, _FormatParameterList(testcase_params))
|
| 228 |
+
if test_method.__doc__:
|
| 229 |
+
BoundParamTest.__doc__ += '\n%s' % (test_method.__doc__,)
|
| 230 |
+
return BoundParamTest
|
| 231 |
+
return (MakeBoundParamTest(c) for c in self.testcases)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def _IsSingletonList(testcases):
|
| 235 |
+
"""True iff testcases contains only a single non-tuple element."""
|
| 236 |
+
return len(testcases) == 1 and not isinstance(testcases[0], tuple)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _ModifyClass(class_object, testcases, naming_type):
|
| 240 |
+
assert not getattr(class_object, '_id_suffix', None), (
|
| 241 |
+
'Cannot add parameters to %s,'
|
| 242 |
+
' which already has parameterized methods.' % (class_object,))
|
| 243 |
+
class_object._id_suffix = id_suffix = {}
|
| 244 |
+
# We change the size of __dict__ while we iterate over it,
|
| 245 |
+
# which Python 3.x will complain about, so use copy().
|
| 246 |
+
for name, obj in class_object.__dict__.copy().items():
|
| 247 |
+
if (name.startswith(unittest.TestLoader.testMethodPrefix)
|
| 248 |
+
and isinstance(obj, types.FunctionType)):
|
| 249 |
+
delattr(class_object, name)
|
| 250 |
+
methods = {}
|
| 251 |
+
_UpdateClassDictForParamTestCase(
|
| 252 |
+
methods, id_suffix, name,
|
| 253 |
+
_ParameterizedTestIter(obj, testcases, naming_type))
|
| 254 |
+
for name, meth in methods.items():
|
| 255 |
+
setattr(class_object, name, meth)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def _ParameterDecorator(naming_type, testcases):
|
| 259 |
+
"""Implementation of the parameterization decorators.
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
naming_type: The naming type.
|
| 263 |
+
testcases: Testcase parameters.
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
A function for modifying the decorated object.
|
| 267 |
+
"""
|
| 268 |
+
def _Apply(obj):
|
| 269 |
+
if isinstance(obj, type):
|
| 270 |
+
_ModifyClass(
|
| 271 |
+
obj,
|
| 272 |
+
list(testcases) if not isinstance(testcases, collections_abc.Sequence)
|
| 273 |
+
else testcases,
|
| 274 |
+
naming_type)
|
| 275 |
+
return obj
|
| 276 |
+
else:
|
| 277 |
+
return _ParameterizedTestIter(obj, testcases, naming_type)
|
| 278 |
+
|
| 279 |
+
if _IsSingletonList(testcases):
|
| 280 |
+
assert _NonStringIterable(testcases[0]), (
|
| 281 |
+
'Single parameter argument must be a non-string iterable')
|
| 282 |
+
testcases = testcases[0]
|
| 283 |
+
|
| 284 |
+
return _Apply
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def parameters(*testcases): # pylint: disable=invalid-name
|
| 288 |
+
"""A decorator for creating parameterized tests.
|
| 289 |
+
|
| 290 |
+
See the module docstring for a usage example.
|
| 291 |
+
Args:
|
| 292 |
+
*testcases: Parameters for the decorated method, either a single
|
| 293 |
+
iterable, or a list of tuples/dicts/objects (for tests
|
| 294 |
+
with only one argument).
|
| 295 |
+
|
| 296 |
+
Returns:
|
| 297 |
+
A test generator to be handled by TestGeneratorMetaclass.
|
| 298 |
+
"""
|
| 299 |
+
return _ParameterDecorator(_ARGUMENT_REPR, testcases)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def named_parameters(*testcases): # pylint: disable=invalid-name
|
| 303 |
+
"""A decorator for creating parameterized tests.
|
| 304 |
+
|
| 305 |
+
See the module docstring for a usage example. The first element of
|
| 306 |
+
each parameter tuple should be a string and will be appended to the
|
| 307 |
+
name of the test method.
|
| 308 |
+
|
| 309 |
+
Args:
|
| 310 |
+
*testcases: Parameters for the decorated method, either a single
|
| 311 |
+
iterable, or a list of tuples.
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
A test generator to be handled by TestGeneratorMetaclass.
|
| 315 |
+
"""
|
| 316 |
+
return _ParameterDecorator(_FIRST_ARG, testcases)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class TestGeneratorMetaclass(type):
|
| 320 |
+
"""Metaclass for test cases with test generators.
|
| 321 |
+
|
| 322 |
+
A test generator is an iterable in a testcase that produces callables. These
|
| 323 |
+
callables must be single-argument methods. These methods are injected into
|
| 324 |
+
the class namespace and the original iterable is removed. If the name of the
|
| 325 |
+
iterable conforms to the test pattern, the injected methods will be picked
|
| 326 |
+
up as tests by the unittest framework.
|
| 327 |
+
|
| 328 |
+
In general, it is supposed to be used in conjunction with the
|
| 329 |
+
parameters decorator.
|
| 330 |
+
"""
|
| 331 |
+
|
| 332 |
+
def __new__(mcs, class_name, bases, dct):
|
| 333 |
+
dct['_id_suffix'] = id_suffix = {}
|
| 334 |
+
for name, obj in dct.copy().items():
|
| 335 |
+
if (name.startswith(unittest.TestLoader.testMethodPrefix) and
|
| 336 |
+
_NonStringIterable(obj)):
|
| 337 |
+
iterator = iter(obj)
|
| 338 |
+
dct.pop(name)
|
| 339 |
+
_UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator)
|
| 340 |
+
|
| 341 |
+
return type.__new__(mcs, class_name, bases, dct)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator):
|
| 345 |
+
"""Adds individual test cases to a dictionary.
|
| 346 |
+
|
| 347 |
+
Args:
|
| 348 |
+
dct: The target dictionary.
|
| 349 |
+
id_suffix: The dictionary for mapping names to test IDs.
|
| 350 |
+
name: The original name of the test case.
|
| 351 |
+
iterator: The iterator generating the individual test cases.
|
| 352 |
+
"""
|
| 353 |
+
for idx, func in enumerate(iterator):
|
| 354 |
+
assert callable(func), 'Test generators must yield callables, got %r' % (
|
| 355 |
+
func,)
|
| 356 |
+
if getattr(func, '__x_use_name__', False):
|
| 357 |
+
new_name = func.__name__
|
| 358 |
+
else:
|
| 359 |
+
new_name = '%s%s%d' % (name, _SEPARATOR, idx)
|
| 360 |
+
assert new_name not in dct, (
|
| 361 |
+
'Name of parameterized test case "%s" not unique' % (new_name,))
|
| 362 |
+
dct[new_name] = func
|
| 363 |
+
id_suffix[new_name] = getattr(func, '__x_extra_id__', '')
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class TestCase(unittest.TestCase, metaclass=TestGeneratorMetaclass):
|
| 367 |
+
"""Base class for test cases using the parameters decorator."""
|
| 368 |
+
|
| 369 |
+
def _OriginalName(self):
|
| 370 |
+
return self._testMethodName.split(_SEPARATOR)[0]
|
| 371 |
+
|
| 372 |
+
def __str__(self):
|
| 373 |
+
return '%s (%s)' % (self._OriginalName(), _StrClass(self.__class__))
|
| 374 |
+
|
| 375 |
+
def id(self): # pylint: disable=invalid-name
|
| 376 |
+
"""Returns the descriptive ID of the test.
|
| 377 |
+
|
| 378 |
+
This is used internally by the unittesting framework to get a name
|
| 379 |
+
for the test to be used in reports.
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
The test id.
|
| 383 |
+
"""
|
| 384 |
+
return '%s.%s%s' % (_StrClass(self.__class__),
|
| 385 |
+
self._OriginalName(),
|
| 386 |
+
self._id_suffix.get(self._testMethodName, ''))
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def CoopTestCase(other_base_class):
|
| 390 |
+
"""Returns a new base class with a cooperative metaclass base.
|
| 391 |
+
|
| 392 |
+
This enables the TestCase to be used in combination
|
| 393 |
+
with other base classes that have custom metaclasses, such as
|
| 394 |
+
mox.MoxTestBase.
|
| 395 |
+
|
| 396 |
+
Only works with metaclasses that do not override type.__new__.
|
| 397 |
+
|
| 398 |
+
Example:
|
| 399 |
+
|
| 400 |
+
import google3
|
| 401 |
+
import mox
|
| 402 |
+
|
| 403 |
+
from google.protobuf.internal import _parameterized
|
| 404 |
+
|
| 405 |
+
class ExampleTest(parameterized.CoopTestCase(mox.MoxTestBase)):
|
| 406 |
+
...
|
| 407 |
+
|
| 408 |
+
Args:
|
| 409 |
+
other_base_class: (class) A test case base class.
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
A new class object.
|
| 413 |
+
"""
|
| 414 |
+
metaclass = type(
|
| 415 |
+
'CoopMetaclass',
|
| 416 |
+
(other_base_class.__metaclass__,
|
| 417 |
+
TestGeneratorMetaclass), {})
|
| 418 |
+
return metaclass(
|
| 419 |
+
'CoopTestCase',
|
| 420 |
+
(other_base_class, TestCase), {})
|
.venv/lib/python3.11/site-packages/google/protobuf/internal/containers.py
ADDED
|
@@ -0,0 +1,677 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Protocol Buffers - Google's data interchange format
|
| 2 |
+
# Copyright 2008 Google Inc. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Use of this source code is governed by a BSD-style
|
| 5 |
+
# license that can be found in the LICENSE file or at
|
| 6 |
+
# https://developers.google.com/open-source/licenses/bsd
|
| 7 |
+
|
| 8 |
+
"""Contains container classes to represent different protocol buffer types.
|
| 9 |
+
|
| 10 |
+
This file defines container classes which represent categories of protocol
|
| 11 |
+
buffer field types which need extra maintenance. Currently these categories
|
| 12 |
+
are:
|
| 13 |
+
|
| 14 |
+
- Repeated scalar fields - These are all repeated fields which aren't
|
| 15 |
+
composite (e.g. they are of simple types like int32, string, etc).
|
| 16 |
+
- Repeated composite fields - Repeated fields which are composite. This
|
| 17 |
+
includes groups and nested messages.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import collections.abc
|
| 21 |
+
import copy
|
| 22 |
+
import pickle
|
| 23 |
+
from typing import (
|
| 24 |
+
Any,
|
| 25 |
+
Iterable,
|
| 26 |
+
Iterator,
|
| 27 |
+
List,
|
| 28 |
+
MutableMapping,
|
| 29 |
+
MutableSequence,
|
| 30 |
+
NoReturn,
|
| 31 |
+
Optional,
|
| 32 |
+
Sequence,
|
| 33 |
+
TypeVar,
|
| 34 |
+
Union,
|
| 35 |
+
overload,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
_T = TypeVar('_T')
|
| 40 |
+
_K = TypeVar('_K')
|
| 41 |
+
_V = TypeVar('_V')
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class BaseContainer(Sequence[_T]):
|
| 45 |
+
"""Base container class."""
|
| 46 |
+
|
| 47 |
+
# Minimizes memory usage and disallows assignment to other attributes.
|
| 48 |
+
__slots__ = ['_message_listener', '_values']
|
| 49 |
+
|
| 50 |
+
def __init__(self, message_listener: Any) -> None:
|
| 51 |
+
"""
|
| 52 |
+
Args:
|
| 53 |
+
message_listener: A MessageListener implementation.
|
| 54 |
+
The RepeatedScalarFieldContainer will call this object's
|
| 55 |
+
Modified() method when it is modified.
|
| 56 |
+
"""
|
| 57 |
+
self._message_listener = message_listener
|
| 58 |
+
self._values = []
|
| 59 |
+
|
| 60 |
+
@overload
|
| 61 |
+
def __getitem__(self, key: int) -> _T:
|
| 62 |
+
...
|
| 63 |
+
|
| 64 |
+
@overload
|
| 65 |
+
def __getitem__(self, key: slice) -> List[_T]:
|
| 66 |
+
...
|
| 67 |
+
|
| 68 |
+
def __getitem__(self, key):
|
| 69 |
+
"""Retrieves item by the specified key."""
|
| 70 |
+
return self._values[key]
|
| 71 |
+
|
| 72 |
+
def __len__(self) -> int:
|
| 73 |
+
"""Returns the number of elements in the container."""
|
| 74 |
+
return len(self._values)
|
| 75 |
+
|
| 76 |
+
def __ne__(self, other: Any) -> bool:
|
| 77 |
+
"""Checks if another instance isn't equal to this one."""
|
| 78 |
+
# The concrete classes should define __eq__.
|
| 79 |
+
return not self == other
|
| 80 |
+
|
| 81 |
+
__hash__ = None
|
| 82 |
+
|
| 83 |
+
def __repr__(self) -> str:
|
| 84 |
+
return repr(self._values)
|
| 85 |
+
|
| 86 |
+
def sort(self, *args, **kwargs) -> None:
|
| 87 |
+
# Continue to support the old sort_function keyword argument.
|
| 88 |
+
# This is expected to be a rare occurrence, so use LBYL to avoid
|
| 89 |
+
# the overhead of actually catching KeyError.
|
| 90 |
+
if 'sort_function' in kwargs:
|
| 91 |
+
kwargs['cmp'] = kwargs.pop('sort_function')
|
| 92 |
+
self._values.sort(*args, **kwargs)
|
| 93 |
+
|
| 94 |
+
def reverse(self) -> None:
|
| 95 |
+
self._values.reverse()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# TODO: Remove this. BaseContainer does *not* conform to
|
| 99 |
+
# MutableSequence, only its subclasses do.
|
| 100 |
+
collections.abc.MutableSequence.register(BaseContainer)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class RepeatedScalarFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, type-checked, list-like container for holding repeated scalars."""

  # Disallows assignment to other attributes.
  __slots__ = ['_type_checker']

  def __init__(
      self,
      message_listener: Any,
      type_checker: Any,
  ) -> None:
    """Args:

    message_listener: A MessageListener implementation. The
      RepeatedScalarFieldContainer will call this object's Modified() method
      when it is modified.
    type_checker: A type_checkers.ValueChecker instance to run on elements
      inserted into this container.
    """
    super().__init__(message_listener)
    self._type_checker = type_checker

  def append(self, value: _T) -> None:
    """Appends an item to the list. Similar to list.append()."""
    # CheckValue validates (and possibly coerces) the value, raising before
    # the list is touched, so a failed append leaves the container intact.
    self._values.append(self._type_checker.CheckValue(value))
    # Only notify on the first mutation of a clean message; later mutations
    # are already covered by the pending notification.
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position. Similar to list.insert()."""
    self._values.insert(key, self._type_checker.CheckValue(value))
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given iterable. Similar to list.extend()."""
    elem_seq_iter = iter(elem_seq)
    # Validate every element up front so a bad element aborts the whole
    # extend without partially mutating the container.
    new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
    if new_values:
      self._values.extend(new_values)
      self._message_listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedScalarFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one. We do not check the types of the individual fields.
    """
    # Unlike extend(), values are trusted here and the listener is notified
    # unconditionally.
    self._values.extend(other)
    self._message_listener.Modified()

  def remove(self, elem: _T):
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: int = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value) -> None:
    """Sets the item on the specified position."""
    if isinstance(key, slice):
      if key.step is not None:
        raise ValueError('Extended slices not supported')
      # List slice assignment consumes the lazy map(), type-checking each
      # element as it is materialized.
      self._values[key] = map(self._type_checker.CheckValue, value)
      self._message_listener.Modified()
    else:
      self._values[key] = self._type_checker.CheckValue(value)
      self._message_listener.Modified()

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one."""
    if self is other:
      return True
    # Special case for the same type which should be common and fast.
    if isinstance(other, self.__class__):
      return other._values == self._values
    # We are presumably comparing against some other sequence type.
    return other == self._values

  def __deepcopy__(
      self,
      unused_memo: Any = None,
  ) -> 'RepeatedScalarFieldContainer[_T]':
    # The listener is deep-copied but the (stateless) type checker is shared.
    clone = RepeatedScalarFieldContainer(
        copy.deepcopy(self._message_listener), self._type_checker)
    clone.MergeFrom(self)
    return clone

  def __reduce__(self, **kwargs) -> NoReturn:
    # Pickling is deliberately unsupported; the container is only meaningful
    # attached to its parent message.
    raise pickle.PickleError(
        "Can't pickle repeated scalar fields, convert to list first")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# TODO: Constrain T to be a subtype of Message.
|
| 215 |
+
class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, list-like container for holding repeated composite fields."""

  # Disallows assignment to other attributes.
  __slots__ = ['_message_descriptor']

  def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
    """
    Note that we pass in a descriptor instead of the generated class directly,
    since at the time we construct a _RepeatedCompositeFieldContainer we
    haven't yet necessarily initialized the type that will be contained in the
    container.

    Args:
      message_listener: A MessageListener implementation.
        The RepeatedCompositeFieldContainer will call this object's
        Modified() method when it is modified.
      message_descriptor: A Descriptor instance describing the protocol type
        that should be present in this container.  We'll use the
        _concrete_class field of this descriptor when the client calls add().
    """
    super().__init__(message_listener)
    self._message_descriptor = message_descriptor

  def add(self, **kwargs: Any) -> _T:
    """Adds a new element at the end of the list and returns it. Keyword
    arguments may be used to initialize the element.
    """
    new_element = self._message_descriptor._concrete_class(**kwargs)
    # Wire the child to the same listener so mutating the submessage also
    # marks the parent message dirty.
    new_element._SetListener(self._message_listener)
    self._values.append(new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()
    return new_element

  def append(self, value: _T) -> None:
    """Appends one element by copying the message."""
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    new_element.CopyFrom(value)
    self._values.append(new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position by copying."""
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    new_element.CopyFrom(value)
    self._values.insert(key, new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given sequence of elements of the same type

    as this one, copying each individual message.
    """
    # Hoist attribute lookups out of the loop; this path is hot when
    # parsing large repeated fields.
    message_class = self._message_descriptor._concrete_class
    listener = self._message_listener
    values = self._values
    for message in elem_seq:
      new_element = message_class()
      new_element._SetListener(listener)
      new_element.MergeFrom(message)
      values.append(new_element)
    # NOTE: Modified() is invoked even when elem_seq is empty.
    listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one, copying each individual message.
    """
    self.extend(other)

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: int = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value):
    # This method is implemented to make RepeatedCompositeFieldContainer
    # structurally compatible with typing.MutableSequence. It is
    # otherwise unsupported and will always raise an error.
    raise TypeError(
        f'{self.__class__.__name__} object does not support item assignment')

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one."""
    if self is other:
      return True
    # Unlike the scalar container, comparison against arbitrary sequences is
    # rejected rather than attempted.
    if not isinstance(other, self.__class__):
      raise TypeError('Can only compare repeated composite fields against '
                      'other repeated composite fields.')
    return self._values == other._values
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class ScalarMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container for holding repeated scalars."""

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
               '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      key_checker: Any,
      value_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The ScalarMap will call this object's Modified() method when it
        is modified.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      value_checker: A type_checkers.ValueChecker instance to run on values
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._key_checker = key_checker
    self._value_checker = value_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    # defaultdict-like: a missing key is inserted (after type checking) with
    # the field type's default value.  Note no Modified() call on this path.
    try:
      return self._values[key]
    except KeyError:
      key = self._key_checker.CheckValue(key)
      val = self._value_checker.DefaultValue()
      self._values[key] = val
      return val

  def __contains__(self, item: _K) -> bool:
    # We check the key's type to match the strong-typing flavor of the API.
    # Also this makes it easier to match the behavior of the C++ implementation.
    self._key_checker.CheckValue(item)
    return item in self._values

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    if key in self:
      return self[key]
    else:
      return default

  def __setitem__(self, key: _K, value: _V) -> None:
    checked_key = self._key_checker.CheckValue(key)
    checked_value = self._value_checker.CheckValue(value)
    self._values[checked_key] = checked_value
    self._message_listener.Modified()

  def __delitem__(self, key: _K) -> None:
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
    # "Last key seen wins" semantics come for free from dict.update().
    self._values.update(other._values)
    self._message_listener.Modified()

  def InvalidateIterators(self) -> None:
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    return self._entry_descriptor._concrete_class
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
class MessageMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container with submessage values."""

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_values', '_message_listener',
               '_message_descriptor', '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      message_descriptor: Any,
      key_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The MessageMap will call this object's Modified() method when it
        is modified.
      message_descriptor: The Descriptor of the submessage values; its
        _concrete_class is instantiated for missing keys.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._message_descriptor = message_descriptor
    self._key_checker = key_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    # defaultdict-like: a missing key gets a fresh submessage wired to the
    # parent's listener, and (unlike ScalarMap) the parent is marked
    # modified immediately.
    key = self._key_checker.CheckValue(key)
    try:
      return self._values[key]
    except KeyError:
      new_element = self._message_descriptor._concrete_class()
      new_element._SetListener(self._message_listener)
      self._values[key] = new_element
      self._message_listener.Modified()
      return new_element

  def get_or_create(self, key: _K) -> _V:
    """get_or_create() is an alias for getitem (ie. map[key]).

    Args:
      key: The key to get or create in the map.

    This is useful in cases where you want to be explicit that the call is
    mutating the map. This can avoid lint errors for statements like this
    that otherwise would appear to be pointless statements:

      msg.my_map[key]
    """
    return self[key]

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    if key in self:
      return self[key]
    else:
      return default

  def __contains__(self, item: _K) -> bool:
    item = self._key_checker.CheckValue(item)
    return item in self._values

  def __setitem__(self, key: _K, value: _V) -> NoReturn:
    # Direct assignment would bypass the listener wiring; callers must
    # mutate through the submessage returned by __getitem__.
    raise ValueError('May not set values directly, call my_map[key].foo = 5')

  def __delitem__(self, key: _K) -> None:
    key = self._key_checker.CheckValue(key)
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
    # pylint: disable=protected-access
    for key in other._values:
      # According to documentation: "When parsing from the wire or when merging,
      # if there are duplicate map keys the last key seen is used".
      if key in self:
        del self[key]
      self[key].CopyFrom(other[key])
    # self._message_listener.Modified() not required here, because
    # mutations to submessages already propagate.

  def InvalidateIterators(self) -> None:
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    return self._entry_descriptor._concrete_class
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
class _UnknownField:
  """A single parsed unknown field: field number, wire type, raw data."""

  # Disallows assignment to other attributes.
  __slots__ = ['_field_number', '_wire_type', '_data']

  def __init__(self, field_number, wire_type, data):
    self._field_number = field_number
    self._wire_type = wire_type
    self._data = data

  def __lt__(self, other):
    """Orders unknown fields by field number (used when sorting for __eq__)."""
    # pylint: disable=protected-access
    return self._field_number < other._field_number

  def __eq__(self, other):
    if self is other:
      return True
    # pylint: disable=protected-access
    return (self._field_number == other._field_number
            and self._wire_type == other._wire_type
            and self._data == other._data)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
class UnknownFieldRef:  # pylint: disable=missing-class-docstring

  def __init__(self, parent, index):
    self._parent = parent
    self._index = index

  def _check_valid(self):
    # A falsy parent means the containing message was cleared; an index at
    # or past the parent's length means this ref was invalidated.
    if not self._parent or self._index >= len(self._parent):
      raise ValueError('UnknownField does not exist. '
                       'The parent message might be cleared.')

  @property
  def field_number(self):
    self._check_valid()
    # pylint: disable=protected-access
    return self._parent._internal_get(self._index)._field_number

  @property
  def wire_type(self):
    self._check_valid()
    # pylint: disable=protected-access
    return self._parent._internal_get(self._index)._wire_type

  @property
  def data(self):
    self._check_valid()
    # pylint: disable=protected-access
    return self._parent._internal_get(self._index)._data
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
class UnknownFieldSet:
  """UnknownField container.

  Elements are exposed to callers as UnknownFieldRef views; after _clear()
  the set is dead and most operations raise ValueError.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_values']

  def __init__(self):
    self._values = []

  def __getitem__(self, index):
    """Returns an UnknownFieldRef for the field at *index*.

    Supports negative indices.  Raises ValueError if the set was cleared
    and IndexError if *index* is out of range.
    """
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    size = len(self._values)
    if index < 0:
      index += size
    if index < 0 or index >= size:
      # Bug fix: this used to be "'index %d out of range'.index", which
      # passed the bound str.index method to IndexError instead of
      # formatting the offending index into the message.
      raise IndexError('index %d out of range' % index)

    return UnknownFieldRef(self, index)

  def _internal_get(self, index):
    # Raw accessor used by UnknownFieldRef; no validity checks here.
    return self._values[index]

  def __len__(self):
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    return len(self._values)

  def _add(self, field_number, wire_type, data):
    """Appends a new _UnknownField and returns it (internal use only)."""
    unknown_field = _UnknownField(field_number, wire_type, data)
    self._values.append(unknown_field)
    return unknown_field

  def __iter__(self):
    # Yield refs (not raw _UnknownField objects) so stale iterators are
    # detected by each ref's own validity check.
    for i in range(len(self)):
      yield UnknownFieldRef(self, i)

  def _extend(self, other):
    """Appends another set's raw values (internal use only)."""
    if other is None:
      return
    # pylint: disable=protected-access
    self._values.extend(other._values)

  def __eq__(self, other):
    if self is other:
      return True
    # Sort unknown fields because their order shouldn't
    # affect equality test.
    values = list(self._values)
    if other is None:
      return not values
    values.sort()
    # pylint: disable=protected-access
    other_values = sorted(other._values)
    return values == other_values

  def _clear(self):
    """Recursively clears nested groups, then kills this set."""
    for value in self._values:
      # pylint: disable=protected-access
      if isinstance(value._data, UnknownFieldSet):
        value._data._clear()  # pylint: disable=protected-access
    # None marks the set as dead; __getitem__/__len__ will raise from now on.
    self._values = None
|
.venv/lib/python3.11/site-packages/google/protobuf/internal/encoder.py
ADDED
|
@@ -0,0 +1,806 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Protocol Buffers - Google's data interchange format
|
| 2 |
+
# Copyright 2008 Google Inc. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Use of this source code is governed by a BSD-style
|
| 5 |
+
# license that can be found in the LICENSE file or at
|
| 6 |
+
# https://developers.google.com/open-source/licenses/bsd
|
| 7 |
+
|
| 8 |
+
"""Code for encoding protocol message primitives.
|
| 9 |
+
|
| 10 |
+
Contains the logic for encoding every logical protocol field type
|
| 11 |
+
into one of the 5 physical wire types.
|
| 12 |
+
|
| 13 |
+
This code is designed to push the Python interpreter's performance to the
|
| 14 |
+
limits.
|
| 15 |
+
|
| 16 |
+
The basic idea is that at startup time, for every field (i.e. every
|
| 17 |
+
FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The
|
| 18 |
+
sizer takes a value of this field's type and computes its byte size. The
|
| 19 |
+
encoder takes a writer function and a value. It encodes the value into byte
|
| 20 |
+
strings and invokes the writer function to write those strings. Typically the
|
| 21 |
+
writer function is the write() method of a BytesIO.
|
| 22 |
+
|
| 23 |
+
We try to do as much work as possible when constructing the writer and the
|
| 24 |
+
sizer rather than when calling them. In particular:
|
| 25 |
+
* We copy any needed global functions to local variables, so that we do not need
|
| 26 |
+
to do costly global table lookups at runtime.
|
| 27 |
+
* Similarly, we try to do any attribute lookups at startup time if possible.
|
| 28 |
+
* Every field's tag is encoded to bytes at startup, since it can't change at
|
| 29 |
+
runtime.
|
| 30 |
+
* Whatever component of the field size we can compute at startup, we do.
|
| 31 |
+
* We *avoid* sharing code if doing so would make the code slower and not sharing
|
| 32 |
+
does not burden us too much. For example, encoders for repeated fields do
|
| 33 |
+
not just call the encoders for singular fields in a loop because this would
|
| 34 |
+
add an extra function call overhead for every loop iteration; instead, we
|
| 35 |
+
manually inline the single-value encoder into the loop.
|
| 36 |
+
* If a Python function lacks a return statement, Python actually generates
|
| 37 |
+
instructions to pop the result of the last statement off the stack, push
|
| 38 |
+
None onto the stack, and then return that. If we really don't care what
|
| 39 |
+
value is returned, then we can save two instructions by returning the
|
| 40 |
+
result of the last statement. It looks funny but it helps.
|
| 41 |
+
* We assume that type and bounds checking has happened at a higher level.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
__author__ = 'kenton@google.com (Kenton Varda)'
|
| 45 |
+
|
| 46 |
+
import struct
|
| 47 |
+
|
| 48 |
+
from google.protobuf.internal import wire_format
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# A decimal literal too large for a double overflows to IEEE-754 "infinity".
# Kept instead of float('inf') for historical reasons (pre-2.6 Windows
# builds lacked float('inf') support).
_POS_INF = 1e10000
_NEG_INF = -_POS_INF
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _VarintSize(value):
  """Return the number of bytes needed to encode *value* as a varint.

  Varints carry 7 payload bits per byte, so each extra 7 bits of magnitude
  costs one more byte, capped at 10 bytes for a full 64-bit value.
  """
  limit = 0x7f
  for nbytes in range(1, 10):
    if value <= limit:
      return nbytes
    limit = (limit << 7) | 0x7f
  return 10
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _SignedVarintSize(value):
  """Return the varint-encoded size of a signed 64-bit value.

  Negative numbers are sign-extended to 64 bits on the wire, so they always
  occupy the maximum 10 bytes; non-negative values size like plain varints.
  """
  if value < 0:
    return 10
  limit = 0x7f
  for nbytes in range(1, 10):
    if value <= limit:
      return nbytes
    limit = (limit << 7) | 0x7f
  return 10
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _TagSize(field_number):
  """Return the number of bytes needed to serialize a tag for this field.

  The wire type never changes the tag's encoded length, so wire type 0 is
  used as a placeholder when packing the tag.
  """
  return _VarintSize(wire_format.PackTag(field_number, 0))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# --------------------------------------------------------------------
|
| 94 |
+
# In this section we define some generic sizers. Each of these functions
|
| 95 |
+
# takes parameters specific to a particular field type, e.g. int32 or fixed64.
|
| 96 |
+
# It returns another function which in turn takes parameters specific to a
|
| 97 |
+
# particular field, e.g. the field number and whether it is repeated or packed.
|
| 98 |
+
# Look at the next section to see how these are used.
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _SimpleSizer(compute_value_size):
  """Factory for sizers that apply *compute_value_size* to each value.

  Typically compute_value_size is _VarintSize.  The returned constructor
  takes (field_number, is_repeated, is_packed) and yields a per-field sizer.
  """

  def SpecificSizer(field_number, is_repeated, is_packed):
    tag_size = _TagSize(field_number)
    if is_packed:
      local_VarintSize = _VarintSize
      def PackedFieldSize(value):
        payload = 0
        for item in value:
          payload += compute_value_size(item)
        # Payload bytes + its varint length prefix + the single tag.
        return payload + local_VarintSize(payload) + tag_size
      return PackedFieldSize
    if is_repeated:
      def RepeatedFieldSize(value):
        # One tag per element, plus each element's encoded size.
        total = tag_size * len(value)
        for item in value:
          total += compute_value_size(item)
        return total
      return RepeatedFieldSize
    def FieldSize(value):
      return tag_size + compute_value_size(value)
    return FieldSize

  return SpecificSizer
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _ModifiedSizer(compute_value_size, modify_value):
  """Like _SimpleSizer, but *modify_value* is applied to each value before
  it reaches *compute_value_size* (typically wire_format.ZigZagEncode)."""

  def SpecificSizer(field_number, is_repeated, is_packed):
    tag_size = _TagSize(field_number)
    if is_packed:
      local_VarintSize = _VarintSize
      def PackedFieldSize(value):
        payload = 0
        for item in value:
          payload += compute_value_size(modify_value(item))
        # Payload bytes + its varint length prefix + the single tag.
        return payload + local_VarintSize(payload) + tag_size
      return PackedFieldSize
    if is_repeated:
      def RepeatedFieldSize(value):
        total = tag_size * len(value)
        for item in value:
          total += compute_value_size(modify_value(item))
        return total
      return RepeatedFieldSize
    def FieldSize(value):
      return tag_size + compute_value_size(modify_value(value))
    return FieldSize

  return SpecificSizer
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _FixedSizer(value_size):
  """Like _SimpleSizer, but for fixed-width fields.

  *value_size* is the byte width of a single encoded value, so most of the
  arithmetic can be precomputed when the sizer is constructed.
  """

  def SpecificSizer(field_number, is_repeated, is_packed):
    tag_size = _TagSize(field_number)
    if is_packed:
      local_VarintSize = _VarintSize
      def PackedFieldSize(value):
        payload = len(value) * value_size
        return payload + local_VarintSize(payload) + tag_size
      return PackedFieldSize
    if is_repeated:
      # Each element costs exactly one tag plus one fixed-width value.
      per_element = value_size + tag_size
      def RepeatedFieldSize(value):
        return len(value) * per_element
      return RepeatedFieldSize
    # Singular: the size is a constant, independent of the value.
    fixed_total = value_size + tag_size
    def FieldSize(unused_value):
      return fixed_total
    return FieldSize

  return SpecificSizer
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# ====================================================================
|
| 186 |
+
# Here we declare a sizer constructor for each field type. Each "sizer
|
| 187 |
+
# constructor" is a function that takes (field_number, is_repeated, is_packed)
|
| 188 |
+
# as parameters and returns a sizer, which in turn takes a field value as
|
| 189 |
+
# a parameter and returns its encoded size.
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# Signed varint types (and enums) share a single sizer.
Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize)

# Unsigned varint types.
UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize)

# ZigZag-transformed signed varints.
SInt32Sizer = SInt64Sizer = _ModifiedSizer(
    _SignedVarintSize, wire_format.ZigZagEncode)

# Fixed-width 32- and 64-bit wire types.
Fixed32Sizer = SFixed32Sizer = FloatSizer = _FixedSizer(4)
Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8)

# Bools occupy exactly one varint byte on the wire.
BoolSizer = _FixedSizer(1)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def StringSizer(field_number, is_repeated, is_packed):
  """Returns a sizer for a string field (length-delimited UTF-8)."""

  tag_size = _TagSize(field_number)
  local_VarintSize = _VarintSize
  local_len = len
  assert not is_packed
  if is_repeated:
    def RepeatedFieldSize(value):
      total = tag_size * len(value)
      for item in value:
        # Sizes are measured on the UTF-8 encoding, not the str length.
        nbytes = local_len(item.encode('utf-8'))
        total += local_VarintSize(nbytes) + nbytes
      return total
    return RepeatedFieldSize
  def FieldSize(value):
    nbytes = local_len(value.encode('utf-8'))
    return tag_size + local_VarintSize(nbytes) + nbytes
  return FieldSize
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def BytesSizer(field_number, is_repeated, is_packed):
  """Returns a sizer for a bytes field (length-delimited raw bytes)."""

  tag_size = _TagSize(field_number)
  local_VarintSize = _VarintSize
  local_len = len
  assert not is_packed
  if is_repeated:
    def RepeatedFieldSize(value):
      total = tag_size * len(value)
      for item in value:
        nbytes = local_len(item)
        total += local_VarintSize(nbytes) + nbytes
      return total
    return RepeatedFieldSize
  def FieldSize(value):
    nbytes = local_len(value)
    return tag_size + local_VarintSize(nbytes) + nbytes
  return FieldSize
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def GroupSizer(field_number, is_repeated, is_packed):
  """Returns a sizer for a group field.

  Groups are framed by a start tag and an end tag, hence the doubled
  tag size.
  """

  tag_size = _TagSize(field_number) * 2  # start-group + end-group tags
  assert not is_packed
  if is_repeated:
    def RepeatedFieldSize(value):
      total = tag_size * len(value)
      for item in value:
        total += item.ByteSize()
      return total
    return RepeatedFieldSize
  def FieldSize(value):
    return tag_size + value.ByteSize()
  return FieldSize
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def MessageSizer(field_number, is_repeated, is_packed):
  """Returns a sizer for an embedded-message field.

  Each message costs one tag, a varint length prefix, and its own
  serialized size (as reported by ByteSize()).
  """

  tag_size = _TagSize(field_number)
  local_VarintSize = _VarintSize
  assert not is_packed
  if is_repeated:
    def RepeatedFieldSize(value):
      total = tag_size * len(value)
      for item in value:
        nbytes = item.ByteSize()
        total += local_VarintSize(nbytes) + nbytes
      return total
    return RepeatedFieldSize
  def FieldSize(value):
    nbytes = value.ByteSize()
    return tag_size + local_VarintSize(nbytes) + nbytes
  return FieldSize
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# --------------------------------------------------------------------
|
| 289 |
+
# MessageSet is special: it needs custom logic to compute its size properly.
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def MessageSetItemSizer(field_number):
  """Returns a sizer for extensions of MessageSet.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """
  # Everything except the payload and its length prefix is constant per
  # field: the Item group's start/end tags, the type_id tag+value, and the
  # message field's tag.
  static_size = (_TagSize(1) * 2 + _TagSize(2) + _VarintSize(field_number) +
                 _TagSize(3))
  local_VarintSize = _VarintSize

  def FieldSize(value):
    payload = value.ByteSize()
    return static_size + local_VarintSize(payload) + payload

  return FieldSize
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# --------------------------------------------------------------------
|
| 315 |
+
# Map is special: it needs custom logic to compute its size properly.
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def MapSizer(field_descriptor, is_message_map):
  """Returns a sizer for a map field.

  Each entry is sized as if it were a singular MapEntry submessage.
  """

  # Can't look at field_descriptor.message_type._concrete_class here
  # because it may not have been initialized yet; defer that lookup to
  # size time.
  entry_type = field_descriptor.message_type
  entry_sizer = MessageSizer(field_descriptor.number, False, False)

  def FieldSize(map_value):
    total = 0
    for key in map_value:
      value = map_value[key]
      # Building a throwaway entry message here is wasteful (the encode
      # pass will build it again), but avoiding that would require a lot
      # of duplicated code.  For message maps, value.ByteSize() must also
      # be called so the value's cached size is refreshed.
      entry = entry_type._concrete_class(key=key, value=value)
      total += entry_sizer(entry)
      if is_message_map:
        value.ByteSize()
    return total

  return FieldSize
|
| 342 |
+
|
| 343 |
+
# ====================================================================
|
| 344 |
+
# Encoders!
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def _VarintEncoder():
  """Return an encoder for a basic varint value (does not include tag)."""

  local_int2byte = struct.Struct('>B').pack

  def EncodeVarint(write, value, unused_deterministic=None):
    # Emit 7 bits at a time, setting the continuation bit (0x80) on every
    # byte except the last.
    low = value & 0x7f
    value >>= 7
    while value:
      write(local_int2byte(0x80 | low))
      low = value & 0x7f
      value >>= 7
    return write(local_int2byte(low))

  return EncodeVarint
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _SignedVarintEncoder():
  """Return an encoder for a basic signed varint value (does not include
  tag)."""

  local_int2byte = struct.Struct('>B').pack

  def EncodeSignedVarint(write, value, unused_deterministic=None):
    # Negative values are sign-extended to 64 bits before encoding, so
    # they always produce the maximum-length (10-byte) varint.
    if value < 0:
      value += (1 << 64)
    low = value & 0x7f
    value >>= 7
    while value:
      write(local_int2byte(0x80 | low))
      low = value & 0x7f
      value >>= 7
    return write(local_int2byte(low))

  return EncodeSignedVarint
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
# Shared module-level encoder instances, built once at import time so the
# closure setup cost is not paid per call.
_EncodeVarint = _VarintEncoder()
_EncodeSignedVarint = _SignedVarintEncoder()
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def _VarintBytes(value):
  """Encode *value* as a varint and return the bytes.

  Only called at startup time, so speed is not a concern.
  """
  chunks = []
  _EncodeVarint(chunks.append, value, True)
  return b"".join(chunks)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def TagBytes(field_number, wire_type):
  """Encode the (field_number, wire_type) tag and return the bytes.

  Only called at startup; the result is cached in encoder closures.
  """
  return bytes(_VarintBytes(wire_format.PackTag(field_number, wire_type)))
|
| 401 |
+
|
| 402 |
+
# --------------------------------------------------------------------
|
| 403 |
+
# As with sizers (see above), we have a number of common encoder
|
| 404 |
+
# implementations.
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _SimpleEncoder(wire_type, encode_value, compute_value_size):
  """Return a constructor for an encoder for fields of a particular type.

  Args:
    wire_type: The field's wire type, for encoding tags.
    encode_value: A function which encodes an individual value, e.g.
      _EncodeVarint().
    compute_value_size: A function which computes the size of an individual
      value, e.g. _VarintSize().
  """

  def SpecificEncoder(field_number, is_repeated, is_packed):
    if is_packed:
      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
      local_EncodeVarint = _EncodeVarint
      def EncodePackedField(write, value, deterministic):
        write(tag_bytes)
        # First pass computes the payload length for the prefix; second
        # pass writes the values.
        payload_size = 0
        for item in value:
          payload_size += compute_value_size(item)
        local_EncodeVarint(write, payload_size, deterministic)
        for item in value:
          encode_value(write, item, deterministic)
      return EncodePackedField
    if is_repeated:
      tag_bytes = TagBytes(field_number, wire_type)
      def EncodeRepeatedField(write, value, deterministic):
        for item in value:
          write(tag_bytes)
          encode_value(write, item, deterministic)
      return EncodeRepeatedField
    tag_bytes = TagBytes(field_number, wire_type)
    def EncodeField(write, value, deterministic):
      write(tag_bytes)
      return encode_value(write, value, deterministic)
    return EncodeField

  return SpecificEncoder
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
  """Like _SimpleEncoder but additionally applies *modify_value* to every
  value before passing it to encode_value.  Usually modify_value is
  ZigZagEncode."""

  def SpecificEncoder(field_number, is_repeated, is_packed):
    if is_packed:
      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
      local_EncodeVarint = _EncodeVarint
      def EncodePackedField(write, value, deterministic):
        write(tag_bytes)
        payload_size = 0
        for item in value:
          payload_size += compute_value_size(modify_value(item))
        local_EncodeVarint(write, payload_size, deterministic)
        for item in value:
          encode_value(write, modify_value(item), deterministic)
      return EncodePackedField
    if is_repeated:
      tag_bytes = TagBytes(field_number, wire_type)
      def EncodeRepeatedField(write, value, deterministic):
        for item in value:
          write(tag_bytes)
          encode_value(write, modify_value(item), deterministic)
      return EncodeRepeatedField
    tag_bytes = TagBytes(field_number, wire_type)
    def EncodeField(write, value, deterministic):
      write(tag_bytes)
      return encode_value(write, modify_value(value), deterministic)
    return EncodeField

  return SpecificEncoder
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def _StructPackEncoder(wire_type, format):
  """Return a constructor for an encoder for a fixed-width field.

  Args:
    wire_type: The field's wire type, for encoding tags.
    format: The format string to pass to struct.pack().
  """

  value_size = struct.calcsize(format)

  def SpecificEncoder(field_number, is_repeated, is_packed):
    local_struct_pack = struct.pack
    if is_packed:
      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
      local_EncodeVarint = _EncodeVarint
      def EncodePackedField(write, value, deterministic):
        write(tag_bytes)
        # Fixed-width values make the payload length a simple product.
        local_EncodeVarint(write, len(value) * value_size, deterministic)
        for item in value:
          write(local_struct_pack(format, item))
      return EncodePackedField
    if is_repeated:
      tag_bytes = TagBytes(field_number, wire_type)
      def EncodeRepeatedField(write, value, unused_deterministic=None):
        for item in value:
          write(tag_bytes)
          write(local_struct_pack(format, item))
      return EncodeRepeatedField
    tag_bytes = TagBytes(field_number, wire_type)
    def EncodeField(write, value, unused_deterministic=None):
      write(tag_bytes)
      return write(local_struct_pack(format, value))
    return EncodeField

  return SpecificEncoder
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def _FloatingPointEncoder(wire_type, format):
  """Return a constructor for an encoder for float fields.

  This is like _StructPackEncoder, but catches errors that may be due to
  passing non-finite floating-point values to struct.pack, and makes a
  second attempt to encode those values.

  Args:
    wire_type: The field's wire type, for encoding tags.
    format: The format string to pass to struct.pack().
  """

  value_size = struct.calcsize(format)
  if value_size == 4:
    def EncodeNonFiniteOrRaise(write, value):
      # Remember that the serialized form uses little-endian byte order.
      if value == _POS_INF:
        write(b'\x00\x00\x80\x7F')
      elif value == _NEG_INF:
        write(b'\x00\x00\x80\xFF')
      elif value != value:  # NaN compares unequal to itself
        write(b'\x00\x00\xC0\x7F')
      else:
        raise  # re-raise the struct.pack error we were called under
  elif value_size == 8:
    def EncodeNonFiniteOrRaise(write, value):
      if value == _POS_INF:
        write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
      elif value == _NEG_INF:
        write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
      elif value != value:  # NaN
        write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
      else:
        raise
  else:
    raise ValueError('Can\'t encode floating-point values that are '
                     '%d bytes long (only 4 or 8)' % value_size)

  def SpecificEncoder(field_number, is_repeated, is_packed):
    local_struct_pack = struct.pack
    if is_packed:
      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
      local_EncodeVarint = _EncodeVarint
      def EncodePackedField(write, value, deterministic):
        write(tag_bytes)
        local_EncodeVarint(write, len(value) * value_size, deterministic)
        for item in value:
          # A try/except here is faster than any explicit finiteness
          # check we could write in Python.
          try:
            write(local_struct_pack(format, item))
          except SystemError:
            EncodeNonFiniteOrRaise(write, item)
      return EncodePackedField
    if is_repeated:
      tag_bytes = TagBytes(field_number, wire_type)
      def EncodeRepeatedField(write, value, unused_deterministic=None):
        for item in value:
          write(tag_bytes)
          try:
            write(local_struct_pack(format, item))
          except SystemError:
            EncodeNonFiniteOrRaise(write, item)
      return EncodeRepeatedField
    tag_bytes = TagBytes(field_number, wire_type)
    def EncodeField(write, value, unused_deterministic=None):
      write(tag_bytes)
      try:
        write(local_struct_pack(format, value))
      except SystemError:
        EncodeNonFiniteOrRaise(write, value)
    return EncodeField

  return SpecificEncoder
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
# ====================================================================
|
| 598 |
+
# Here we declare an encoder constructor for each field type. These work
|
| 599 |
+
# very similarly to sizer constructors, described earlier.
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
# Signed varint types (and enums) share a single encoder.
Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)

# Unsigned varint types.
UInt32Encoder = UInt64Encoder = _SimpleEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)

# ZigZag-transformed signed varints.
SInt32Encoder = SInt64Encoder = _ModifiedEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize,
    wire_format.ZigZagEncode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f')
DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d')
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def BoolEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a boolean field.

  Bools are written as a single varint byte: 0x00 or 0x01.
  """

  false_byte = b'\x00'
  true_byte = b'\x01'
  if is_packed:
    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    local_EncodeVarint = _EncodeVarint
    def EncodePackedField(write, value, deterministic):
      write(tag_bytes)
      # One byte per element, so the payload length equals the count.
      local_EncodeVarint(write, len(value), deterministic)
      for item in value:
        write(true_byte if item else false_byte)
    return EncodePackedField
  if is_repeated:
    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    def EncodeRepeatedField(write, value, unused_deterministic=None):
      for item in value:
        write(tag_bytes)
        write(true_byte if item else false_byte)
    return EncodeRepeatedField
  tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
  def EncodeField(write, value, unused_deterministic=None):
    write(tag_bytes)
    return write(true_byte if value else false_byte)
  return EncodeField
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
def StringEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a string field.

  Strings are written UTF-8 encoded and length-delimited.
  """

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  local_len = len
  assert not is_packed
  if is_repeated:
    def EncodeRepeatedField(write, value, deterministic):
      for item in value:
        data = item.encode('utf-8')
        write(tag)
        local_EncodeVarint(write, local_len(data), deterministic)
        write(data)
    return EncodeRepeatedField
  def EncodeField(write, value, deterministic):
    data = value.encode('utf-8')
    write(tag)
    local_EncodeVarint(write, local_len(data), deterministic)
    return write(data)
  return EncodeField
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
def BytesEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a bytes field (length-delimited raw bytes)."""

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  local_len = len
  assert not is_packed
  if is_repeated:
    def EncodeRepeatedField(write, value, deterministic):
      for item in value:
        write(tag)
        local_EncodeVarint(write, local_len(item), deterministic)
        write(item)
    return EncodeRepeatedField
  def EncodeField(write, value, deterministic):
    write(tag)
    local_EncodeVarint(write, local_len(value), deterministic)
    return write(value)
  return EncodeField
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
def GroupEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a group field.

  Groups are framed with start-group / end-group tags rather than a
  length prefix.
  """

  start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP)
  end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
  assert not is_packed
  if is_repeated:
    def EncodeRepeatedField(write, value, deterministic):
      for msg in value:
        write(start_tag)
        msg._InternalSerialize(write, deterministic)
        write(end_tag)
    return EncodeRepeatedField
  def EncodeField(write, value, deterministic):
    write(start_tag)
    value._InternalSerialize(write, deterministic)
    return write(end_tag)
  return EncodeField
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def MessageEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for an embedded-message field.

  Each message is written as tag, varint payload size (from ByteSize()),
  then the serialized submessage.
  """

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  assert not is_packed
  if is_repeated:
    def EncodeRepeatedField(write, value, deterministic):
      for msg in value:
        write(tag)
        local_EncodeVarint(write, msg.ByteSize(), deterministic)
        msg._InternalSerialize(write, deterministic)
    return EncodeRepeatedField
  def EncodeField(write, value, deterministic):
    write(tag)
    local_EncodeVarint(write, value.ByteSize(), deterministic)
    return value._InternalSerialize(write, deterministic)
  return EncodeField
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
# --------------------------------------------------------------------
|
| 750 |
+
# As before, MessageSet is special.
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
def MessageSetItemEncoder(field_number):
  """Encoder for extensions of MessageSet.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """
  # Everything up to the payload length prefix is constant per field, so
  # it is assembled once at construction time.
  start_bytes = b"".join([
      TagBytes(1, wire_format.WIRETYPE_START_GROUP),
      TagBytes(2, wire_format.WIRETYPE_VARINT),
      _VarintBytes(field_number),
      TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)])
  end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
  local_EncodeVarint = _EncodeVarint

  def EncodeField(write, value, deterministic):
    write(start_bytes)
    local_EncodeVarint(write, value.ByteSize(), deterministic)
    value._InternalSerialize(write, deterministic)
    return write(end_bytes)

  return EncodeField
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
# --------------------------------------------------------------------
|
| 782 |
+
# As before, Map is special.
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
def MapEncoder(field_descriptor):
  """Encoder for map fields.

  (The previous docstring, "Encoder for extensions of MessageSet.", was a
  copy-paste error from MessageSetItemEncoder.)

  Maps always have a wire format like this:
    message MapEntry {
      key_type key = 1;
      value_type value = 2;
    }
    repeated MapEntry map = N;

  Each entry is serialized as a singular embedded MapEntry message.  Under
  deterministic serialization, entries are emitted in sorted key order.
  """
  # Can't look at field_descriptor.message_type._concrete_class because it may
  # not have been initialized yet.
  message_type = field_descriptor.message_type
  encode_message = MessageEncoder(field_descriptor.number, False, False)

  def EncodeField(write, value, deterministic):
    # Sorting the keys gives a canonical byte stream when deterministic
    # output was requested; otherwise dict iteration order is used.
    value_keys = sorted(value.keys()) if deterministic else value
    for key in value_keys:
      entry_msg = message_type._concrete_class(key=key, value=value[key])
      encode_message(write, entry_msg, deterministic)

  return EncodeField
|
.venv/lib/python3.11/site-packages/google/protobuf/internal/python_edition_defaults.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Serialized FeatureSetDefaults object for the Pure Python runtime.

Used for feature resolution under protobuf Editions.  The constant below is
opaque wire-format data and must not be edited by hand.
"""
_PROTOBUF_INTERNAL_PYTHON_EDITION_DEFAULTS = b"\n\023\030\204\007\"\000*\014\010\001\020\002\030\002 \003(\0010\002\n\023\030\347\007\"\000*\014\010\002\020\001\030\001 \002(\0010\001\n\023\030\350\007\"\014\010\001\020\001\030\001 \002(\0010\001*\000 \346\007(\350\007"
|
.venv/lib/python3.11/site-packages/google/protobuf/internal/testing_refleaks.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Protocol Buffers - Google's data interchange format
|
| 2 |
+
# Copyright 2008 Google Inc. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Use of this source code is governed by a BSD-style
|
| 5 |
+
# license that can be found in the LICENSE file or at
|
| 6 |
+
# https://developers.google.com/open-source/licenses/bsd
|
| 7 |
+
|
| 8 |
+
"""A subclass of unittest.TestCase which checks for reference leaks.
|
| 9 |
+
|
| 10 |
+
To use:
|
| 11 |
+
- Use testing_refleak.BaseTestCase instead of unittest.TestCase
|
| 12 |
+
- Configure and compile Python with --with-pydebug
|
| 13 |
+
|
| 14 |
+
If sys.gettotalrefcount() is not available (because Python was built without
|
| 15 |
+
the Py_DEBUG option), then this module is a no-op and tests will run normally.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import copyreg
|
| 19 |
+
import gc
|
| 20 |
+
import sys
|
| 21 |
+
import unittest
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class LocalTestResult(unittest.TestResult):
  """A TestResult which forwards events to a parent object, except for Skips."""

  def __init__(self, parent_result):
    super().__init__()
    # The aggregating result that receives the forwarded events.
    self.parent_result = parent_result

  def addError(self, test, error):
    # Errors are real signal; pass them straight through.
    self.parent_result.addError(test, error)

  def addFailure(self, test, error):
    # Failures are real signal; pass them straight through.
    self.parent_result.addFailure(test, error)

  def addSkip(self, test, reason):
    # Deliberately dropped: the leak checker re-runs each test several
    # times, and forwarding skips would report them repeatedly.
    pass
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ReferenceLeakCheckerMixin(object):
  """A mixin class for TestCase, which checks reference counts."""

  # Number of measured runs whose refcount delta must all be zero.
  NB_RUNS = 3

  def run(self, result=None):
    # Tests decorated with expectedFailure are skipped entirely: an
    # expected failure raises on every run, which makes refcount deltas
    # meaningless.
    testMethod = getattr(self, self._testMethodName)
    expecting_failure_method = getattr(testMethod, "__unittest_expecting_failure__", False)
    expecting_failure_class = getattr(self, "__unittest_expecting_failure__", False)
    if expecting_failure_class or expecting_failure_method:
      return

    # python_message.py registers all Message classes to some pickle global
    # registry, which makes the classes immortal.
    # We save a copy of this registry, and reset it before we count
    # references, so classes registered during a run don't show up as leaks.
    self._saved_pickle_registry = copyreg.dispatch_table.copy()

    # Run the test twice, to warm up the instance attributes.
    super(ReferenceLeakCheckerMixin, self).run(result=result)
    super(ReferenceLeakCheckerMixin, self).run(result=result)

    oldrefcount = 0
    # Measurement runs report into a LocalTestResult so skips (and the
    # repeated pass/fail events) are not double-counted in `result`.
    local_result = LocalTestResult(result)
    num_flakes = 0

    refcount_deltas = []
    while len(refcount_deltas) < self.NB_RUNS:
      oldrefcount = self._getRefcounts()
      super(ReferenceLeakCheckerMixin, self).run(result=local_result)
      newrefcount = self._getRefcounts()
      # If the GC was able to collect some objects after the call to run() that
      # it could not collect before the call, then the counts won't match.
      if newrefcount < oldrefcount and num_flakes < 2:
        # This result is (probably) a flake -- garbage collectors aren't very
        # predictable, but a lower ending refcount is the opposite of the
        # failure we are testing for. If the result is repeatable, then we will
        # eventually report it, but not after trying to eliminate it.
        num_flakes += 1
        continue
      num_flakes = 0
      refcount_deltas.append(newrefcount - oldrefcount)
    print(refcount_deltas, self)

    try:
      # Any nonzero delta means references leaked across a run.
      self.assertEqual(refcount_deltas, [0] * self.NB_RUNS)
    except Exception:  # pylint: disable=broad-except
      # NOTE(review): if run() was called with result=None this raises
      # AttributeError here — callers appear to always pass a result.
      result.addError(self, sys.exc_info())

  def _getRefcounts(self):
    # Restore the pickle registry saved in run() so entries added during
    # the test don't count as live references.
    copyreg.dispatch_table.clear()
    copyreg.dispatch_table.update(self._saved_pickle_registry)
    # It is sometimes necessary to gc.collect() multiple times, to ensure
    # that all objects can be collected.
    gc.collect()
    gc.collect()
    gc.collect()
    # Only available on Py_DEBUG builds of CPython.
    return sys.gettotalrefcount()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
if hasattr(sys, 'gettotalrefcount'):
  # Py_DEBUG build: wrap test classes with the leak-checking mixin.

  def TestCase(test_class):
    # Prepend the mixin so its run() overrides the TestCase run().
    new_bases = (ReferenceLeakCheckerMixin,) + test_class.__bases__
    new_class = type(test_class)(
        test_class.__name__, new_bases, dict(test_class.__dict__))
    return new_class
  # Tests excluded from leak checking are genuinely skipped on this build.
  SkipReferenceLeakChecker = unittest.skip

else:
  # When PyDEBUG is not enabled, run the tests normally.

  def TestCase(test_class):
    # No-op: leave the class unchanged.
    return test_class

  def SkipReferenceLeakChecker(reason):
    del reason  # Don't skip, so don't need a reason.
    def Same(func):
      # Identity decorator: the test still runs, just without leak checks.
      return func
    return Same
|
.venv/lib/python3.11/site-packages/google/protobuf/internal/well_known_types.py
ADDED
|
@@ -0,0 +1,678 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Protocol Buffers - Google's data interchange format
|
| 2 |
+
# Copyright 2008 Google Inc. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Use of this source code is governed by a BSD-style
|
| 5 |
+
# license that can be found in the LICENSE file or at
|
| 6 |
+
# https://developers.google.com/open-source/licenses/bsd
|
| 7 |
+
|
| 8 |
+
"""Contains well known classes.
|
| 9 |
+
|
| 10 |
+
This files defines well known classes which need extra maintenance including:
|
| 11 |
+
- Any
|
| 12 |
+
- Duration
|
| 13 |
+
- FieldMask
|
| 14 |
+
- Struct
|
| 15 |
+
- Timestamp
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
__author__ = 'jieluo@google.com (Jie Luo)'
|
| 19 |
+
|
| 20 |
+
import calendar
|
| 21 |
+
import collections.abc
|
| 22 |
+
import datetime
|
| 23 |
+
import warnings
|
| 24 |
+
from google.protobuf.internal import field_mask
|
| 25 |
+
from typing import Union
|
| 26 |
+
|
| 27 |
+
FieldMask = field_mask.FieldMask
|
| 28 |
+
|
| 29 |
+
# strptime pattern for the whole-seconds part of an RFC 3339 timestamp;
# fractional digits and the timezone offset are parsed separately.
# NOTE: the name keeps its historical typo ("FOMAT") because other code in
# this module references it.
_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
# Unit-conversion factors used by the Timestamp/Duration helpers below.
_NANOS_PER_SECOND = 1000000000
_NANOS_PER_MILLISECOND = 1000000
_NANOS_PER_MICROSECOND = 1000
_MILLIS_PER_SECOND = 1000
_MICROS_PER_SECOND = 1000000
_SECONDS_PER_DAY = 24 * 3600
# Range limits enforced by _CheckDurationValid / _CheckTimestampValid.
_DURATION_SECONDS_MAX = 315576000000
_TIMESTAMP_SECONDS_MIN = -62135596800
_TIMESTAMP_SECONDS_MAX = 253402300799

# Unix-epoch reference datetimes used for Timestamp <-> datetime conversion;
# one naive, one UTC-aware.
_EPOCH_DATETIME_NAIVE = datetime.datetime(1970, 1, 1, tzinfo=None)
_EPOCH_DATETIME_AWARE = _EPOCH_DATETIME_NAIVE.replace(
    tzinfo=datetime.timezone.utc
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Any(object):
  """Class for Any Message type.

  Stores an arbitrary serialized message together with a type URL that
  identifies its protobuf type.
  """

  __slots__ = ()

  def Pack(self, msg, type_url_prefix='type.googleapis.com/',
           deterministic=None):
    """Packs the specified message into current Any message."""
    full_name = msg.DESCRIPTOR.full_name
    if type_url_prefix.endswith('/'):
      self.type_url = type_url_prefix + full_name
    else:
      # Insert the separator when the prefix is empty or lacks one.
      self.type_url = '{0}/{1}'.format(type_url_prefix, full_name)
    self.value = msg.SerializeToString(deterministic=deterministic)

  def Unpack(self, msg):
    """Unpacks the current Any message into specified message."""
    if not self.Is(msg.DESCRIPTOR):
      return False
    msg.ParseFromString(self.value)
    return True

  def TypeName(self):
    """Returns the protobuf type name of the inner message."""
    # Only last part is to be used: b/25630112
    return self.type_url.rpartition('/')[2]

  def Is(self, descriptor):
    """Checks if this Any represents the given protobuf type."""
    if '/' not in self.type_url:
      return False
    return self.TypeName() == descriptor.full_name
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Timestamp(object):
  """Class for Timestamp message type."""

  __slots__ = ()

  def ToJsonString(self):
    """Converts Timestamp to RFC 3339 date string format.

    Returns:
      A string converted from timestamp. The string is always Z-normalized
      and uses 3, 6 or 9 fractional digits as required to represent the
      exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
    """
    _CheckTimestampValid(self.seconds, self.nanos)
    nanos = self.nanos
    # Split into whole days plus a same-day remainder so the timedelta
    # arithmetic below works across the full Timestamp range.
    seconds = self.seconds % _SECONDS_PER_DAY
    days = (self.seconds - seconds) // _SECONDS_PER_DAY
    dt = datetime.datetime(1970, 1, 1) + datetime.timedelta(days, seconds)

    result = dt.isoformat()
    if (nanos % 1e9) == 0:
      # If there are 0 fractional digits, the fractional
      # point '.' should be omitted when serializing.
      return result + 'Z'
    if (nanos % 1e6) == 0:
      # Serialize 3 fractional digits.
      return result + '.%03dZ' % (nanos / 1e6)
    if (nanos % 1e3) == 0:
      # Serialize 6 fractional digits.
      return result + '.%06dZ' % (nanos / 1e3)
    # Serialize 9 fractional digits.
    return result + '.%09dZ' % nanos

  def FromJsonString(self, value):
    """Parse a RFC 3339 date string format to Timestamp.

    Args:
      value: A date string. Any fractional digits (or none) and any offset are
        accepted as long as they fit into nano-seconds precision.
        Example of accepted format: '1972-01-01T10:00:20.021-05:00'

    Raises:
      ValueError: On parsing problems.
    """
    if not isinstance(value, str):
      raise ValueError('Timestamp JSON value not a string: {!r}'.format(value))
    # Locate the timezone marker: 'Z', '+HH:MM', or '-HH:MM'.  rfind is
    # needed for '-' since the date part itself contains hyphens.
    timezone_offset = value.find('Z')
    if timezone_offset == -1:
      timezone_offset = value.find('+')
      if timezone_offset == -1:
        timezone_offset = value.rfind('-')
        if timezone_offset == -1:
          raise ValueError(
              'Failed to parse timestamp: missing valid timezone offset.')
    time_value = value[0:timezone_offset]
    # Parse datetime and nanos.
    point_position = time_value.find('.')
    if point_position == -1:
      second_value = time_value
      nano_value = ''
    else:
      second_value = time_value[:point_position]
      nano_value = time_value[point_position + 1:]
    if 't' in second_value:
      raise ValueError(
          'time data \'{0}\' does not match format \'%Y-%m-%dT%H:%M:%S\', '
          'lowercase \'t\' is not accepted'.format(second_value))
    date_object = datetime.datetime.strptime(second_value, _TIMESTAMPFOMAT)
    td = date_object - datetime.datetime(1970, 1, 1)
    seconds = td.seconds + td.days * _SECONDS_PER_DAY
    if len(nano_value) > 9:
      raise ValueError(
          'Failed to parse Timestamp: nanos {0} more than '
          '9 fractional digits.'.format(nano_value))
    if nano_value:
      # Treat the digits as a decimal fraction of a second.
      nanos = round(float('0.' + nano_value) * 1e9)
    else:
      nanos = 0
    # Parse timezone offsets.
    if value[timezone_offset] == 'Z':
      if len(value) != timezone_offset + 1:
        raise ValueError('Failed to parse timestamp: invalid trailing'
                         ' data {0}.'.format(value))
    else:
      timezone = value[timezone_offset:]
      pos = timezone.find(':')
      if pos == -1:
        raise ValueError(
            'Invalid timezone offset value: {0}.'.format(timezone))
      # Normalize to UTC: subtract positive offsets, add negative ones.
      if timezone[0] == '+':
        seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
      else:
        seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
    # Set seconds and nanos
    _CheckTimestampValid(seconds, nanos)
    self.seconds = int(seconds)
    self.nanos = int(nanos)

  def GetCurrentTime(self):
    """Get the current UTC into Timestamp."""
    # NOTE(review): datetime.utcnow() is deprecated in Python 3.12+;
    # FromDatetime treats naive datetimes as UTC, which matches its output.
    self.FromDatetime(datetime.datetime.utcnow())

  def ToNanoseconds(self):
    """Converts Timestamp to nanoseconds since epoch."""
    _CheckTimestampValid(self.seconds, self.nanos)
    return self.seconds * _NANOS_PER_SECOND + self.nanos

  def ToMicroseconds(self):
    """Converts Timestamp to microseconds since epoch (truncating nanos)."""
    _CheckTimestampValid(self.seconds, self.nanos)
    return (self.seconds * _MICROS_PER_SECOND +
            self.nanos // _NANOS_PER_MICROSECOND)

  def ToMilliseconds(self):
    """Converts Timestamp to milliseconds since epoch (truncating nanos)."""
    _CheckTimestampValid(self.seconds, self.nanos)
    return (self.seconds * _MILLIS_PER_SECOND +
            self.nanos // _NANOS_PER_MILLISECOND)

  def ToSeconds(self):
    """Converts Timestamp to seconds since epoch (dropping nanos)."""
    _CheckTimestampValid(self.seconds, self.nanos)
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds since epoch to Timestamp."""
    seconds = nanos // _NANOS_PER_SECOND
    nanos = nanos % _NANOS_PER_SECOND
    _CheckTimestampValid(seconds, nanos)
    self.seconds = seconds
    self.nanos = nanos

  def FromMicroseconds(self, micros):
    """Converts microseconds since epoch to Timestamp."""
    seconds = micros // _MICROS_PER_SECOND
    nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND
    _CheckTimestampValid(seconds, nanos)
    self.seconds = seconds
    self.nanos = nanos

  def FromMilliseconds(self, millis):
    """Converts milliseconds since epoch to Timestamp."""
    seconds = millis // _MILLIS_PER_SECOND
    nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND
    _CheckTimestampValid(seconds, nanos)
    self.seconds = seconds
    self.nanos = nanos

  def FromSeconds(self, seconds):
    """Converts seconds since epoch to Timestamp."""
    _CheckTimestampValid(seconds, 0)
    self.seconds = seconds
    self.nanos = 0

  def ToDatetime(self, tzinfo=None):
    """Converts Timestamp to a datetime.

    Args:
      tzinfo: A datetime.tzinfo subclass; defaults to None.

    Returns:
      If tzinfo is None, returns a timezone-naive UTC datetime (with no timezone
      information, i.e. not aware that it's UTC).

      Otherwise, returns a timezone-aware datetime in the input timezone.
    """
    # Using datetime.fromtimestamp for this would avoid constructing an extra
    # timedelta object and possibly an extra datetime. Unfortunately, that has
    # the disadvantage of not handling the full precision (on all platforms, see
    # https://github.com/python/cpython/issues/109849) or full range (on some
    # platforms, see https://github.com/python/cpython/issues/110042) of
    # datetime.
    _CheckTimestampValid(self.seconds, self.nanos)
    delta = datetime.timedelta(
        seconds=self.seconds,
        microseconds=_RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND),
    )
    if tzinfo is None:
      return _EPOCH_DATETIME_NAIVE + delta
    else:
      # Note the tz conversion has to come after the timedelta arithmetic.
      return (_EPOCH_DATETIME_AWARE + delta).astimezone(tzinfo)

  def FromDatetime(self, dt):
    """Converts datetime to Timestamp.

    Args:
      dt: A datetime. If it's timezone-naive, it's assumed to be in UTC.
    """
    # Using this guide: http://wiki.python.org/moin/WorkingWithTime
    # And this conversion guide: http://docs.python.org/library/time.html

    # Turn the date parameter into a tuple (struct_time) that can then be
    # manipulated into a long value of seconds. During the conversion from
    # struct_time to long, the source date in UTC, and so it follows that the
    # correct transformation is calendar.timegm()
    try:
      seconds = calendar.timegm(dt.utctimetuple())
      nanos = dt.microsecond * _NANOS_PER_MICROSECOND
    except AttributeError as e:
      raise AttributeError(
          'Fail to convert to Timestamp. Expected a datetime like '
          'object got {0} : {1}'.format(type(dt).__name__, e)
      ) from e
    _CheckTimestampValid(seconds, nanos)
    self.seconds = seconds
    self.nanos = nanos

  def _internal_assign(self, dt):
    # Hook used by field assignment: ts_field = some_datetime.
    self.FromDatetime(dt)

  def __add__(self, value) -> datetime.datetime:
    # Timestamp + Duration (or timedelta) yields a datetime.
    if isinstance(value, Duration):
      return self.ToDatetime() + value.ToTimedelta()
    return self.ToDatetime() + value

  __radd__ = __add__

  def __sub__(self, value) -> Union[datetime.datetime, datetime.timedelta]:
    # Timestamp - Timestamp -> timedelta; Timestamp - Duration -> datetime.
    if isinstance(value, Timestamp):
      return self.ToDatetime() - value.ToDatetime()
    elif isinstance(value, Duration):
      return self.ToDatetime() - value.ToTimedelta()
    return self.ToDatetime() - value

  def __rsub__(self, dt) -> datetime.timedelta:
    return dt - self.ToDatetime()
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def _CheckTimestampValid(seconds, nanos):
  """Raises ValueError if (seconds, nanos) is outside the Timestamp range.

  Args:
    seconds: Seconds since the Unix epoch; must lie within the protobuf
      Timestamp range [0001-01-01, 9999-12-31].
    nanos: Non-negative sub-second remainder in nanoseconds.

  Raises:
    ValueError: If either component is out of range.
  """
  if seconds < _TIMESTAMP_SECONDS_MIN or seconds > _TIMESTAMP_SECONDS_MAX:
    raise ValueError(
        'Timestamp is not valid: Seconds {0} must be in range '
        '[-62135596800, 253402300799].'.format(seconds))
  if nanos < 0 or nanos >= _NANOS_PER_SECOND:
    # Fix: the message previously claimed the range was [0, 999999], but the
    # check is against _NANOS_PER_SECOND (1e9) — nanoseconds, not micros.
    raise ValueError(
        'Timestamp is not valid: Nanos {} must be in a range '
        '[0, 999999999].'.format(nanos))
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class Duration(object):
  """Class for Duration message type."""

  __slots__ = ()

  def ToJsonString(self):
    """Converts Duration to string format.

    Returns:
      A string converted from self. The string format will contain
      3, 6, or 9 fractional digits depending on the precision required to
      represent the exact Duration value. For example: "1s", "1.010s",
      "1.000000100s", "-3.100s"
    """
    _CheckDurationValid(self.seconds, self.nanos)
    if self.seconds < 0 or self.nanos < 0:
      # Emit a single leading '-' and work with the magnitudes.
      result = '-'
      seconds = - self.seconds + int((0 - self.nanos) // 1e9)
      nanos = (0 - self.nanos) % 1e9
    else:
      result = ''
      seconds = self.seconds + int(self.nanos // 1e9)
      nanos = self.nanos % 1e9
    result += '%d' % seconds
    if (nanos % 1e9) == 0:
      # If there are 0 fractional digits, the fractional
      # point '.' should be omitted when serializing.
      return result + 's'
    if (nanos % 1e6) == 0:
      # Serialize 3 fractional digits.
      return result + '.%03ds' % (nanos / 1e6)
    if (nanos % 1e3) == 0:
      # Serialize 6 fractional digits.
      return result + '.%06ds' % (nanos / 1e3)
    # Serialize 9 fractional digits.
    return result + '.%09ds' % nanos

  def FromJsonString(self, value):
    """Converts a string to Duration.

    Args:
      value: A string to be converted. The string must end with 's'. Any
        fractional digits (or none) are accepted as long as they fit into
        precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s"

    Raises:
      ValueError: On parsing problems.
    """
    if not isinstance(value, str):
      raise ValueError('Duration JSON value not a string: {!r}'.format(value))
    if len(value) < 1 or value[-1] != 's':
      raise ValueError(
          'Duration must end with letter "s": {0}.'.format(value))
    try:
      pos = value.find('.')
      if pos == -1:
        seconds = int(value[:-1])
        nanos = 0
      else:
        seconds = int(value[:pos])
        # The fractional part carries the overall sign so that seconds and
        # nanos agree in sign, as the Duration message requires.
        if value[0] == '-':
          nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
        else:
          nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
      _CheckDurationValid(seconds, nanos)
      self.seconds = seconds
      self.nanos = nanos
    except ValueError as e:
      raise ValueError(
          'Couldn\'t parse duration: {0} : {1}.'.format(value, e))

  def ToNanoseconds(self):
    """Converts a Duration to nanoseconds."""
    return self.seconds * _NANOS_PER_SECOND + self.nanos

  def ToMicroseconds(self):
    """Converts a Duration to microseconds (truncating toward zero)."""
    micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
    return self.seconds * _MICROS_PER_SECOND + micros

  def ToMilliseconds(self):
    """Converts a Duration to milliseconds (truncating toward zero)."""
    millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND)
    return self.seconds * _MILLIS_PER_SECOND + millis

  def ToSeconds(self):
    """Converts a Duration to seconds (dropping nanos)."""
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds to Duration."""
    self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
                            nanos % _NANOS_PER_SECOND)

  def FromMicroseconds(self, micros):
    """Converts microseconds to Duration."""
    self._NormalizeDuration(
        micros // _MICROS_PER_SECOND,
        (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)

  def FromMilliseconds(self, millis):
    """Converts milliseconds to Duration."""
    self._NormalizeDuration(
        millis // _MILLIS_PER_SECOND,
        (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)

  def FromSeconds(self, seconds):
    """Converts seconds to Duration."""
    self.seconds = seconds
    self.nanos = 0

  def ToTimedelta(self) -> datetime.timedelta:
    """Converts Duration to timedelta (truncating nanos toward zero)."""
    return datetime.timedelta(
        seconds=self.seconds, microseconds=_RoundTowardZero(
            self.nanos, _NANOS_PER_MICROSECOND))

  def FromTimedelta(self, td):
    """Converts timedelta to Duration."""
    try:
      # timedelta normalizes to (days, seconds, microseconds); recombine.
      self._NormalizeDuration(
          td.seconds + td.days * _SECONDS_PER_DAY,
          td.microseconds * _NANOS_PER_MICROSECOND,
      )
    except AttributeError as e:
      raise AttributeError(
          'Fail to convert to Duration. Expected a timedelta like '
          'object got {0}: {1}'.format(type(td).__name__, e)
      ) from e

  def _internal_assign(self, td):
    # Hook used by field assignment: duration_field = some_timedelta.
    self.FromTimedelta(td)

  def _NormalizeDuration(self, seconds, nanos):
    """Set Duration by seconds and nanos."""
    # Force nanos to be negative if the duration is negative.
    if seconds < 0 and nanos > 0:
      seconds += 1
      nanos -= _NANOS_PER_SECOND
    self.seconds = seconds
    self.nanos = nanos

  def __add__(self, value) -> Union[datetime.datetime, datetime.timedelta]:
    # Duration + Timestamp -> datetime; Duration + timedelta -> timedelta.
    if isinstance(value, Timestamp):
      return self.ToTimedelta() + value.ToDatetime()
    return self.ToTimedelta() + value

  __radd__ = __add__

  def __rsub__(self, dt) -> Union[datetime.datetime, datetime.timedelta]:
    return dt - self.ToTimedelta()
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def _CheckDurationValid(seconds, nanos):
  """Raises ValueError unless (seconds, nanos) is a valid Duration.

  A valid Duration lies within +/-10000 years and has seconds and nanos
  of the same sign (or either equal to zero).
  """
  if not -_DURATION_SECONDS_MAX <= seconds <= _DURATION_SECONDS_MAX:
    raise ValueError(
        'Duration is not valid: Seconds {0} must be in range '
        '[-315576000000, 315576000000].'.format(seconds))
  if not -_NANOS_PER_SECOND < nanos < _NANOS_PER_SECOND:
    raise ValueError(
        'Duration is not valid: Nanos {0} must be in range '
        '[-999999999, 999999999].'.format(nanos))
  if seconds * nanos < 0:
    # Opposite signs: equivalent to (nanos<0 and seconds>0) or vice versa.
    raise ValueError(
        'Duration is not valid: Sign mismatch.')
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def _RoundTowardZero(value, divider):
|
| 486 |
+
"""Truncates the remainder part after division."""
|
| 487 |
+
# For some languages, the sign of the remainder is implementation
|
| 488 |
+
# dependent if any of the operands is negative. Here we enforce
|
| 489 |
+
# "rounded toward zero" semantics. For example, for (-5) / 2 an
|
| 490 |
+
# implementation may give -3 as the result with the remainder being
|
| 491 |
+
# 1. This function ensures we always return -2 (closer to zero).
|
| 492 |
+
result = value // divider
|
| 493 |
+
remainder = value % divider
|
| 494 |
+
if result < 0 and remainder > 0:
|
| 495 |
+
return result + 1
|
| 496 |
+
else:
|
| 497 |
+
return result
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def _SetStructValue(struct_value, value):
  """Stores a Python value into the matching oneof field of a Value message.

  Args:
    struct_value: A google.protobuf.Value message to populate.
    value: None, bool, str, number, dict/Struct, or list/tuple/ListValue.

  Raises:
    ValueError: If the Python type has no Value mapping.
  """
  if value is None:
    struct_value.null_value = 0
  elif isinstance(value, bool):
    # Note: this check must come before the number check because in Python
    # True and False are also considered numbers.
    struct_value.bool_value = value
  elif isinstance(value, str):
    struct_value.string_value = value
  elif isinstance(value, (int, float)):
    struct_value.number_value = value
  elif isinstance(value, (dict, Struct)):
    # Clear first so assignment replaces rather than merges into old content.
    struct_value.struct_value.Clear()
    struct_value.struct_value.update(value)
  elif isinstance(value, (list, tuple, ListValue)):
    struct_value.list_value.Clear()
    struct_value.list_value.extend(value)
  else:
    raise ValueError('Unexpected type')
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def _GetStructValue(struct_value):
|
| 522 |
+
which = struct_value.WhichOneof('kind')
|
| 523 |
+
if which == 'struct_value':
|
| 524 |
+
return struct_value.struct_value
|
| 525 |
+
elif which == 'null_value':
|
| 526 |
+
return None
|
| 527 |
+
elif which == 'number_value':
|
| 528 |
+
return struct_value.number_value
|
| 529 |
+
elif which == 'string_value':
|
| 530 |
+
return struct_value.string_value
|
| 531 |
+
elif which == 'bool_value':
|
| 532 |
+
return struct_value.bool_value
|
| 533 |
+
elif which == 'list_value':
|
| 534 |
+
return struct_value.list_value
|
| 535 |
+
elif which is None:
|
| 536 |
+
raise ValueError('Value not set')
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
class Struct(object):
  """Class for Struct message type.

  Provides a dict-like interface over the message's ``fields`` map; values
  are wrapped/unwrapped via _SetStructValue / _GetStructValue.
  """

  __slots__ = ()

  def __getitem__(self, key):
    # Unwrap the stored Value message into a plain Python value.
    return _GetStructValue(self.fields[key])

  def __setitem__(self, key, value):
    _SetStructValue(self.fields[key], value)

  def __delitem__(self, key):
    del self.fields[key]

  def __len__(self):
    return len(self.fields)

  def __iter__(self):
    # Iterates over keys, like a dict.
    return iter(self.fields)

  def _internal_assign(self, dictionary):
    # Assignment replaces the whole contents rather than merging.
    self.Clear()
    self.update(dictionary)

  def _internal_compare(self, other):
    """Compares this Struct with a dict-like object, entry by entry."""
    size = len(self)
    if size != len(other):
      return False
    for key, value in self.items():
      if key not in other:
        return False
      if isinstance(other[key], (dict, list)):
        # Nested Struct/ListValue values compare themselves against plain
        # dicts/lists via their own _internal_compare.
        if not value._internal_compare(other[key]):
          return False
      elif value != other[key]:
        return False
    return True

  def keys(self):  # pylint: disable=invalid-name
    return self.fields.keys()

  def values(self):  # pylint: disable=invalid-name
    # Indexing through self[...] unwraps each Value message.
    return [self[key] for key in self]

  def items(self):  # pylint: disable=invalid-name
    return [(key, self[key]) for key in self]

  def get_or_create_list(self, key):
    """Returns a list for this key, creating if it didn't exist already."""
    if not self.fields[key].HasField('list_value'):
      # Clear will mark list_value modified which will indeed create a list.
      self.fields[key].list_value.Clear()
    return self.fields[key].list_value

  def get_or_create_struct(self, key):
    """Returns a struct for this key, creating if it didn't exist already."""
    if not self.fields[key].HasField('struct_value'):
      # Clear will mark struct_value modified which will indeed create a struct.
      self.fields[key].struct_value.Clear()
    return self.fields[key].struct_value

  def update(self, dictionary):  # pylint: disable=invalid-name
    """Merges every (key, value) pair of *dictionary* into this Struct."""
    for key, value in dictionary.items():
      _SetStructValue(self.fields[key], value)

# Register so isinstance(struct, collections.abc.MutableMapping) holds even
# though Struct does not inherit from it.
collections.abc.MutableMapping.register(Struct)
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
class ListValue(object):
  """Mixin giving google.protobuf.ListValue messages a list-like API.

  The generated message class supplies the repeated ``values`` field and
  message methods such as ``Clear``; this mixin layers Python sequence
  semantics (indexing, append/extend, element iteration via items()) on
  top, wrapping/unwrapping values via _SetStructValue/_GetStructValue.
  """

  # No instance attributes of our own; all state lives in the message.
  __slots__ = ()

  def __len__(self):
    return len(self.values)

  def append(self, value):
    """Wrap *value* in a new Value message and append it."""
    _SetStructValue(self.values.add(), value)

  def extend(self, elem_seq):
    """Append every element of *elem_seq*, in order."""
    for element in elem_seq:
      self.append(element)

  def __getitem__(self, index):
    """Retrieves item by the specified index."""
    return _GetStructValue(self.values[index])

  def __setitem__(self, index, value):
    """Wrap *value* and store it at *index* (which must already exist)."""
    _SetStructValue(self.values[index], value)

  def __delitem__(self, key):
    del self.values[key]

  def _internal_assign(self, elem_seq):
    # Replace the whole contents with *elem_seq*.
    self.Clear()
    self.extend(elem_seq)

  def _internal_compare(self, other):
    # Element-wise equality against a plain Python sequence, recursing
    # into nested dict/list elements via their own _internal_compare.
    if len(self) != len(other):
      return False
    for position in range(len(self)):
      theirs = other[position]
      if isinstance(theirs, (dict, list)):
        if not self[position]._internal_compare(theirs):
          return False
      elif self[position] != theirs:
        return False
    return True

  def items(self):
    """Yield each element's unwrapped Python value, in order."""
    for position in range(len(self)):
      yield self[position]

  def add_struct(self):
    """Appends and returns a struct value as the next value in the list."""
    new_struct = self.values.add().struct_value
    # Clear marks struct_value modified, which materializes the struct.
    new_struct.Clear()
    return new_struct

  def add_list(self):
    """Appends and returns a list value as the next value in the list."""
    new_list = self.values.add().list_value
    # Clear marks list_value modified, which materializes the list.
    new_list.Clear()
    return new_list


collections.abc.MutableSequence.register(ListValue)
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
# LINT.IfChange(wktbases)
# Maps fully-qualified well-known-type proto names to the helper mixin
# classes defined in this module. NOTE(review): presumably consumed
# elsewhere to inject these as base classes of the generated message
# classes — confirm against the message builder/factory.
WKTBASES = {
    'google.protobuf.Any': Any,
    'google.protobuf.Duration': Duration,
    'google.protobuf.FieldMask': FieldMask,
    'google.protobuf.ListValue': ListValue,
    'google.protobuf.Struct': Struct,
    'google.protobuf.Timestamp': Timestamp,
}
# LINT.ThenChange(//depot/google.protobuf/compiler/python/pyi_generator.cc:wktbases)
|