Spaces:
Running
Running
| # Copyright 2024 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Callable responsible for running Inference on provided patches.""" | |
| import functools | |
| from typing import Any, Mapping | |
| from huggingface_hub import from_pretrained_keras | |
| from ez_wsi_dicomweb import credential_factory | |
| from ez_wsi_dicomweb import dicom_slide | |
| from ez_wsi_dicomweb import patch_embedding | |
| from ez_wsi_dicomweb import dicom_web_interface | |
| from ez_wsi_dicomweb import patch_embedding_endpoints | |
| from ez_wsi_dicomweb.ml_toolkit import dicom_path | |
| import numpy as np | |
| import tensorflow as tf | |
| from data_models import embedding_response | |
| from data_models import embedding_request | |
| from data_models import embedding_converter | |
| #from huggingface_hub import hf_hub_download | |
| from huggingface_hub import snapshot_download | |
| def _load_huggingface_model() -> tf.keras.Model: | |
| snapshot_download("google/path-foundation", local_dir="./model") | |
| return tf.keras.layers.TFSMLayer('./model', call_endpoint='serving_default') | |
| #return from_pretrained_keras("./model", compile=False) | |
| def _endpoint_model(ml_model: tf.keras.Model, image: np.ndarray) -> np.ndarray: | |
| """Function ez-wsi will use to run local ML model.""" | |
| result = ml_model.signatures['serving_default']( | |
| tf.cast(tf.constant(image), tf.float32) | |
| ) | |
| return result['output_0'].numpy() | |
| # _ENDPOINT_MODEL = functools.partial(_endpoint_model, _load_huggingface_model()) | |
| class PetePredictor: | |
| """Callable responsible for generating embeddings.""" | |
| def predict( | |
| self, | |
| prediction_input: Mapping[str, Any], | |
| ) -> Mapping[str, Any]: | |
| """Runs inference on provided patches. | |
| Args: | |
| prediction_input: JSON formatted input for embedding prediction. | |
| model: ModelRunner to handle model step. | |
| Returns: | |
| JSON formatted output. | |
| Raises: | |
| ERROR_LOADING_DICOM: If the provided patches are not concated. | |
| """ | |
| embedding_json_converter = embedding_converter.EmbeddingConverterV2() | |
| request = embedding_json_converter.json_to_embedding_request(prediction_input) | |
| endpoint = patch_embedding_endpoints.LocalEndpoint(_ENDPOINT_MODEL) | |
| embedding_results = [] | |
| for instance in request.instances: | |
| patches = [] | |
| if not isinstance(instance, embedding_request.DicomImageV2): | |
| raise ValueError('unsupported') | |
| token = instance.bearer_token | |
| if token: | |
| cf = credential_factory.TokenPassthroughCredentialFactory(token) | |
| else: | |
| cf = credential_factory.NoAuthCredentialsFactory() | |
| dwi = dicom_web_interface.DicomWebInterface(cf) | |
| path = dicom_path.FromString(instance.series_path) | |
| ds = dicom_slide.DicomSlide(dwi=dwi, path=path) | |
| level = ds.get_instance_level(instance.instance_uids[0]) | |
| for coor in instance.patch_coordinates: | |
| patches.append(ds.get_patch(level, coor.x_origin, coor.y_origin, coor.width, coor.height)) | |
| patch_embeddings = [] | |
| for index, result in enumerate(patch_embedding.generate_patch_embeddings(endpoint, patches)): | |
| embedding = np.array(result.embedding) | |
| patch_embeddings.append( | |
| embedding_response.PatchEmbeddingV2( | |
| embedding_vector=embedding.tolist(), | |
| patch_coordinate=instance.patch_coordinates[index], | |
| )) | |
| embedding_results.append( | |
| embedding_response.embedding_instance_response_v2(patch_embeddings) | |
| ) | |
| return embedding_converter.embedding_response_v2_to_json(embedding_results) | |