Spaces:
Running
Running
| # Copyright 2024 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Converts Embedding Requests and Responses to json and vice versa.""" | |
| import dataclasses | |
| import json | |
| from typing import Any, List, Mapping, Sequence | |
| from ez_wsi_dicomweb import patch_embedding_endpoints | |
| import pete_errors | |
| from data_models import embedding_request | |
| from data_models import embedding_response | |
| from data_models import patch_coordinate as patch_coordinate_module | |
# Local alias for the endpoint JSON key constants defined in
# ez_wsi_dicomweb.patch_embedding_endpoints.
_EndpointJsonKeys = patch_embedding_endpoints.EndpointJsonKeys
class ValidationError(Exception):
  """Raised by the validate_* helpers when a JSON value fails a check."""
| class _InvalidCoordinateError(Exception): | |
| pass | |
| class _InstanceUIDMetadataError(Exception): | |
| pass | |
def validate_int(val: Any) -> int:
  """Returns `val` as an int, accepting a float only if it is integral.

  A float such as `10.0` is converted to `10`. Non-finite floats are rejected
  as invalid: `int()` raises OverflowError for +/-inf and ValueError for NaN,
  and Python's `json` module decodes the literals `Infinity` / `NaN` into
  such floats by default, so they can arrive in request payloads.

  Args:
    val: Value to validate.

  Returns:
    `val` as an int.

  Raises:
    ValidationError: If `val` is neither an int nor an integral, finite float.
  """
  if isinstance(val, float):
    try:
      cast_val = int(val)
    except (OverflowError, ValueError) as exp:
      raise ValidationError('coordinate value is not int') from exp
    if cast_val != val:
      raise ValidationError('coordinate value is not int')
    return cast_val
  if not isinstance(val, int):
    raise ValidationError('coordinate value is not int')
  return val
def _get_patch_coord(patch_coordinates: Sequence[Mapping[str, Any]]):
  """Parses and validates the patch coordinates of a request instance.

  Args:
    patch_coordinates: Decoded JSON value expected to be a non-empty list of
      dicts, each holding keyword arguments for
      `patch_coordinate_module.create_patch_coordinate`.

  Returns:
    List of objects returned by `create_patch_coordinate`, in input order.

  Raises:
    _InvalidCoordinateError: If `patch_coordinates` is not a list or is
      empty, or if any entry is not a dict, has keys `create_patch_coordinate`
      rejects, or has a non-integer x_origin, y_origin, width, or height.
  """
  if not isinstance(patch_coordinates, list):
    raise _InvalidCoordinateError('patch_coordinates is not list')
  # Every loop iteration either appends or raises, so an empty input is the
  # only way to end with no coordinates.
  if not patch_coordinates:
    raise _InvalidCoordinateError('empty patch_coordinates')
  parsed = []
  for entry in patch_coordinates:
    try:
      coord = patch_coordinate_module.create_patch_coordinate(**entry)
    except TypeError as exp:
      # A TypeError is treated as either `**entry` not being a mapping or the
      # mapping's keys not matching create_patch_coordinate's parameters.
      if not isinstance(entry, dict):
        raise _InvalidCoordinateError('Patch coordinate is not dict.') from exp
      # Derive the accepted key names from a throwaway coordinate.
      expected = ', '.join(
          dataclasses.asdict(
              patch_coordinate_module.create_patch_coordinate(0, 0)
          )
      )
      raise _InvalidCoordinateError(
          f'Patch coordinate dict has invalid keys; expecting: {expected}'
      ) from exp
    components = (coord.x_origin, coord.y_origin, coord.width, coord.height)
    try:
      for component in components:
        validate_int(component)
    except ValidationError as exp:
      raise _InvalidCoordinateError(
          f'Invalid patch coordinate; x_origin: {coord.x_origin}, y_origin:'
          f' {coord.y_origin}, width: {coord.width}, height: {coord.height}'
      ) from exp
    parsed.append(coord)
  return parsed
def embedding_response_v1_to_json(
    response: embedding_response.EmbeddingResponseV1,
) -> Mapping[str, Any]:
  """Converts a V1 embedding response into the API's JSON payload.

  The response is converted recursively with `dataclasses.asdict`. When an
  error response is present, its `error_code` is replaced by
  `error_code.value` (presumably an Enum member's raw value) so the payload
  holds a plain value.

  Args:
    response: Structured EmbeddingResponseV1 object.

  Returns:
    Dict mapping `_EndpointJsonKeys.PREDICTIONS` to the converted response.
  """
  json_response = dataclasses.asdict(response)
  if response.error_response:
    # asdict() copies non-dataclass field values as-is, so swap in `.value`.
    json_response['error_response']['error_code'] = (
        response.error_response.error_code.value
    )
  return {_EndpointJsonKeys.PREDICTIONS: json_response}
def embedding_response_v2_to_json(
    json_response: Sequence[Mapping[str, Any]],
) -> Mapping[str, Any]:
  """Wraps V2 results under the predictions key of the API payload.

  Args:
    json_response: Sequence of result mappings; neither copied nor validated.

  Returns:
    Dict with the single key `_EndpointJsonKeys.PREDICTIONS` mapped to
    `json_response`.
  """
  payload = {}
  payload[_EndpointJsonKeys.PREDICTIONS] = json_response
  return payload
def validate_str_list(val: Any) -> List[str]:
  """Returns `val` unchanged if it is a list of non-empty strings.

  An empty list is accepted; callers that need at least one element must
  check that themselves.

  Args:
    val: Value to validate.

  Returns:
    `val`, unchanged.

  Raises:
    ValidationError: If `val` is not a list, or any element is not a
      non-empty string.
  """
  # Check against the builtin; typing.List is a deprecated alias of it.
  if not isinstance(val, list):
    raise ValidationError('not list')
  if not all(isinstance(v, str) and v for v in val):
    raise ValidationError('list contains invalid value')
  return val
def _validate_instance_uids_not_empty_str_list(val: Any) -> List[str]:
  """Returns `val` if it is a non-empty list of non-empty strings.

  Args:
    val: Candidate list of instance UIDs.

  Returns:
    `val`, unchanged.

  Raises:
    _InstanceUIDMetadataError: If `val` fails `validate_str_list` (chained to
      the ValidationError) or is an empty list.
  """
  try:
    uids = validate_str_list(val)
  except ValidationError as exp:
    raise _InstanceUIDMetadataError() from exp
  if uids:
    return uids
  raise _InstanceUIDMetadataError('list is empty')
def validate_str_key_dict(val: Any) -> Mapping[str, Any]:
  """Returns `val` unchanged if it is a dict whose keys are non-empty strings.

  Values are not inspected. An empty dict is accepted.

  Args:
    val: Value to validate.

  Returns:
    `val`, unchanged.

  Raises:
    ValidationError: If `val` is not a dict, or has an empty or non-string key.
  """
  if not isinstance(val, dict):
    raise ValidationError('not a dict')
  # Iterating an empty dict is a no-op, so no separate emptiness guard.
  for key in val:
    if not isinstance(key, str) or not key:
      raise ValidationError('dict contains invalid key')
  return val
def validate_str(val: Any) -> str:
  """Returns `val` unchanged if it is a str; the empty string is allowed.

  Raises:
    ValidationError: If `val` is not a str.
  """
  if isinstance(val, str):
    return val
  raise ValidationError('not string')
| def _validate_not_empty_str(val: Any) -> str: | |
| if not isinstance(val, str) or not val: | |
| raise ValidationError('not string or empty') | |
| return val | |
| def _generate_instance_metadata_error_string( | |
| metadata: Mapping[str, Any], *keys: str | |
| ) -> str: | |
| """returns instance metadata as a error string.""" | |
| result = {} | |
| for key in keys: | |
| if key not in metadata: | |
| continue | |
| if key == _EndpointJsonKeys.EXTENSIONS: | |
| value = metadata[key] | |
| if isinstance(value, Mapping): | |
| value = dict(value) | |
| # Strip ez_wsi_state from output. | |
| # Not contributing to validation errors here and may be very large. | |
| if _EndpointJsonKeys.EZ_WSI_STATE in value: | |
| del value[_EndpointJsonKeys.EZ_WSI_STATE] | |
| result[key] = value | |
| continue | |
| elif key == _EndpointJsonKeys.BEARER_TOKEN: | |
| value = metadata[key] | |
| # If bearer token is present, and defined strip | |
| if isinstance(value, str) and value: | |
| result[key] = 'PRESENT' | |
| continue | |
| # otherwise just associate key and value. | |
| result[key] = metadata[key] | |
| return json.dumps(result, sort_keys=True) | |
def _validate_instance_list(json_metadata: Mapping[str, Any]) -> List[Any]:
  """Returns the list of instances from a decoded request.

  Args:
    json_metadata: Decoded request JSON.

  Returns:
    The list stored under `_EndpointJsonKeys.INSTANCES`.

  Raises:
    pete_errors.InvalidRequestFieldError: If `json_metadata` is not a mapping
      or does not map the instances key to a list.
  """
  # Guard before calling .get(): a body that decodes to a list, string or
  # number would otherwise escape as an AttributeError.
  if isinstance(json_metadata, Mapping):
    val = json_metadata.get(_EndpointJsonKeys.INSTANCES)
    if isinstance(val, list):
      return val
  raise pete_errors.InvalidRequestFieldError(
      'Invalid input, missing expected'
      f' key: {_EndpointJsonKeys.INSTANCES} and associated list of values.'
  )
class EmbeddingConverterV1:
  """Converts V1 embedding request JSON into structured request objects."""

  def json_to_embedding_request(
      self, json_metadata: Mapping[str, Any]
  ) -> embedding_request.EmbeddingRequestV1:
    """Converts json to embedding request.

    Args:
      json_metadata: The value of the JSON payload provided to the API.

    Returns:
      Structured EmbeddingRequest object.

    Raises:
      InvalidRequestFieldError: If the provided fields are invalid.
    """
    instances = []
    try:
      model_params = json_metadata[_EndpointJsonKeys.PARAMETERS]
      # .get() below needs a JSON object; a list/str/number here would
      # otherwise escape as an uncaught AttributeError.
      if not isinstance(model_params, dict):
        raise pete_errors.InvalidRequestFieldError(
            'Invalid model size and/or kind parameters.'
        )
      try:
        parameters = embedding_request.EmbeddingParameters(
            model_size=_validate_not_empty_str(
                model_params.get(_EndpointJsonKeys.MODEL_SIZE)
            ),
            model_kind=_validate_not_empty_str(
                model_params.get(_EndpointJsonKeys.MODEL_KIND)
            ),
        )
      except ValidationError as exp:
        raise pete_errors.InvalidRequestFieldError(
            'Invalid model size and/or kind parameters.'
        ) from exp
      for instance in _validate_instance_list(json_metadata):
        # Every field below is read with instance.get(); reject non-objects.
        if not isinstance(instance, dict):
          raise pete_errors.InvalidRequestFieldError(
              'Invalid instance; expecting a JSON object.'
          )
        # EZ-WSI state is accepted either as a JSON object or as a string.
        ez_wsi_state = instance.get(_EndpointJsonKeys.EZ_WSI_STATE, {})
        try:
          ez_wsi_state = validate_str_key_dict(ez_wsi_state)
        except ValidationError:
          try:
            ez_wsi_state = validate_str(ez_wsi_state)
          except ValidationError as exp:
            raise pete_errors.InvalidRequestFieldError(
                'Invalid EZ-WSI state metadata.'
            ) from exp
        try:
          instances.append(
              embedding_request.EmbeddingInstanceV1(
                  dicom_web_store_url=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.DICOM_WEB_STORE_URL)
                  ),
                  dicom_study_uid=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.DICOM_STUDY_UID)
                  ),
                  dicom_series_uid=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.DICOM_SERIES_UID)
                  ),
                  bearer_token=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.BEARER_TOKEN)
                  ),
                  ez_wsi_state=ez_wsi_state,
                  instance_uids=_validate_instance_uids_not_empty_str_list(
                      instance.get(_EndpointJsonKeys.INSTANCE_UIDS)
                  ),
                  patch_coordinates=_get_patch_coord(
                      instance.get(_EndpointJsonKeys.PATCH_COORDINATES)
                  ),
              )
          )
        except ValidationError as exp:
          instance_error_msg = _generate_instance_metadata_error_string(
              instance,
              _EndpointJsonKeys.DICOM_WEB_STORE_URL,
              _EndpointJsonKeys.DICOM_STUDY_UID,
              _EndpointJsonKeys.DICOM_SERIES_UID,
              _EndpointJsonKeys.BEARER_TOKEN,
              _EndpointJsonKeys.INSTANCE_UIDS,
          )
          raise pete_errors.InvalidRequestFieldError(
              f'Invalid instance; {instance_error_msg}'
          ) from exp
        except _InstanceUIDMetadataError as exp:
          # Report the instance UID field (and the study/series it belongs
          # to), not the patch coordinates.
          instance_error_msg = _generate_instance_metadata_error_string(
              instance,
              _EndpointJsonKeys.DICOM_STUDY_UID,
              _EndpointJsonKeys.DICOM_SERIES_UID,
              _EndpointJsonKeys.INSTANCE_UIDS,
          )
          raise pete_errors.InvalidRequestFieldError(
              f'Invalid DICOM SOP Instance UID metadata; {instance_error_msg}'
          ) from exp
        except _InvalidCoordinateError as exp:
          instance_error_msg = _generate_instance_metadata_error_string(
              instance,
              _EndpointJsonKeys.PATCH_COORDINATES,
          )
          raise pete_errors.InvalidRequestFieldError(
              f'Invalid patch coordinate; {exp}; {instance_error_msg}'
          ) from exp
    except (TypeError, ValueError, KeyError) as exp:
      # Do not echo the request: it can carry the bearer token and a large
      # EZ-WSI state blob, both of which the per-field messages redact.
      raise pete_errors.InvalidRequestFieldError(
          f'Invalid input; {type(exp).__name__}: {exp}'
      ) from exp
    return embedding_request.EmbeddingRequestV1(
        parameters=parameters, instances=instances
    )
class EmbeddingConverterV2:
  """Converts V2 embedding request JSON into structured request objects."""

  def json_to_embedding_request(
      self, json_metadata: Mapping[str, Any]
  ) -> embedding_request.EmbeddingRequestV2:
    """Converts json to embedding request.

    The image source of each instance is chosen by the first of these keys
    present: DICOM path, image file URI, raw image bytes.

    Args:
      json_metadata: The value of the JSON payload provided to the API.

    Returns:
      Structured EmbeddingRequest object.

    Raises:
      InvalidRequestFieldError: If the provided fields are invalid.
    """
    instances = []
    for instance in _validate_instance_list(json_metadata):
      # Every field below is read with instance.get(); reject non-objects so
      # they fail as a request error rather than an AttributeError.
      if not isinstance(instance, dict):
        raise pete_errors.InvalidRequestFieldError(
            'Invalid instance; expecting a JSON object.'
        )
      try:
        patch_coordinates = _get_patch_coord(
            instance.get(_EndpointJsonKeys.PATCH_COORDINATES)
        )
      except _InvalidCoordinateError as exp:
        instance_error_msg = _generate_instance_metadata_error_string(
            instance,
            _EndpointJsonKeys.PATCH_COORDINATES,
        )
        raise pete_errors.InvalidRequestFieldError(
            f'Invalid patch coordinate; {exp}; {instance_error_msg}'
        ) from exp
      if _EndpointJsonKeys.DICOM_PATH in instance:
        try:
          dicom_path = validate_str_key_dict(
              instance.get(_EndpointJsonKeys.DICOM_PATH)
          )
        except ValidationError as exp:
          raise pete_errors.InvalidRequestFieldError(
              'Invalid DICOM path.'
          ) from exp
        # Series path and instance UIDs are read from the DICOM path object,
        # not the top-level instance, so report that whole object on error.
        dicom_error_keys = (
            _EndpointJsonKeys.DICOM_PATH,
            _EndpointJsonKeys.BEARER_TOKEN,
            _EndpointJsonKeys.EXTENSIONS,
        )
        try:
          instances.append(
              embedding_request.DicomImageV2(
                  series_path=_validate_not_empty_str(
                      dicom_path.get(_EndpointJsonKeys.SERIES_PATH)
                  ),
                  bearer_token=validate_str(
                      instance.get(_EndpointJsonKeys.BEARER_TOKEN, '')
                  ),
                  extensions=validate_str_key_dict(
                      instance.get(_EndpointJsonKeys.EXTENSIONS, {})
                  ),
                  instance_uids=_validate_instance_uids_not_empty_str_list(
                      dicom_path.get(_EndpointJsonKeys.INSTANCE_UIDS)
                  ),
                  patch_coordinates=patch_coordinates,
              )
          )
        except _InstanceUIDMetadataError as exp:
          error_msg = _generate_instance_metadata_error_string(
              instance, *dicom_error_keys
          )
          raise pete_errors.InvalidRequestFieldError(
              f'Invalid DICOM SOP Instance UID metadata; {error_msg}'
          ) from exp
        except ValidationError as exp:
          error_msg = _generate_instance_metadata_error_string(
              instance, *dicom_error_keys
          )
          raise pete_errors.InvalidRequestFieldError(
              f'DICOM instance JSON formatting is invalid; {error_msg}'
          ) from exp
      elif _EndpointJsonKeys.IMAGE_FILE_URI in instance:
        try:
          instances.append(
              embedding_request.GcsImageV2(
                  image_file_uri=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.IMAGE_FILE_URI)
                  ),
                  bearer_token=validate_str(
                      instance.get(_EndpointJsonKeys.BEARER_TOKEN, '')
                  ),
                  extensions=validate_str_key_dict(
                      instance.get(_EndpointJsonKeys.EXTENSIONS, {})
                  ),
                  patch_coordinates=patch_coordinates,
              )
          )
        except ValidationError as exp:
          error_msg = _generate_instance_metadata_error_string(
              instance,
              _EndpointJsonKeys.IMAGE_FILE_URI,
              _EndpointJsonKeys.BEARER_TOKEN,
              _EndpointJsonKeys.EXTENSIONS,
          )
          raise pete_errors.InvalidRequestFieldError(
              'Google Cloud Storage instance JSON formatting is invalid;'
              f' {error_msg}'
          ) from exp
      elif _EndpointJsonKeys.RAW_IMAGE_BYTES in instance:
        try:
          instances.append(
              embedding_request.EmbeddedImageV2(
                  image_bytes=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.RAW_IMAGE_BYTES)
                  ),
                  extensions=validate_str_key_dict(
                      instance.get(_EndpointJsonKeys.EXTENSIONS, {})
                  ),
                  patch_coordinates=patch_coordinates,
              )
          )
        except ValidationError as exp:
          # The raw image bytes are deliberately not echoed; they may be
          # large. (The image-file-URI key cannot be present on this branch.)
          error_msg = _generate_instance_metadata_error_string(
              instance,
              _EndpointJsonKeys.BEARER_TOKEN,
              _EndpointJsonKeys.EXTENSIONS,
          )
          raise pete_errors.InvalidRequestFieldError(
              'Embedded image instance JSON formatting is invalid;'
              f' {error_msg}'
          ) from exp
      else:
        raise pete_errors.InvalidRequestFieldError(
            'unidentified type; instance must define one of:'
            f' {_EndpointJsonKeys.DICOM_PATH},'
            f' {_EndpointJsonKeys.IMAGE_FILE_URI},'
            f' {_EndpointJsonKeys.RAW_IMAGE_BYTES}'
        )
    return embedding_request.EmbeddingRequestV2(instances)